From 1bea8e1eacc68bcedbd3f207395bccf11011dae2 Mon Sep 17 00:00:00 2001 From: Juan Linietsky Date: Fri, 1 May 2020 09:34:23 -0300 Subject: New lightmapper -Added LocalVector (needed it) -Added stb_rect_pack (It's pretty cool, we could probably use it for other stuff too) -Fixes and changes all around the place -Added library for 128 bits fixed point (required for Delaunay3D) --- thirdparty/oidn/.gitignore | 1 + thirdparty/oidn/common/barrier.h | 52 + thirdparty/oidn/common/exception.h | 45 + thirdparty/oidn/common/platform.cpp | 114 + thirdparty/oidn/common/platform.h | 131 + thirdparty/oidn/common/ref.h | 163 + thirdparty/oidn/common/tensor.cpp | 83 + thirdparty/oidn/common/tensor.h | 66 + thirdparty/oidn/common/thread.cpp | 297 ++ thirdparty/oidn/common/thread.h | 202 + thirdparty/oidn/common/timer.h | 49 + thirdparty/oidn/core/api.cpp | 408 ++ thirdparty/oidn/core/autoencoder.cpp | 519 +++ thirdparty/oidn/core/autoencoder.h | 116 + thirdparty/oidn/core/buffer.h | 75 + thirdparty/oidn/core/common.h | 133 + thirdparty/oidn/core/device.cpp | 205 + thirdparty/oidn/core/device.h | 78 + thirdparty/oidn/core/filter.cpp | 27 + thirdparty/oidn/core/filter.h | 52 + thirdparty/oidn/core/image.h | 111 + thirdparty/oidn/core/input_reorder.h | 232 + thirdparty/oidn/core/math.h | 78 + thirdparty/oidn/core/network.cpp | 434 ++ thirdparty/oidn/core/network.h | 112 + thirdparty/oidn/core/node.h | 142 + thirdparty/oidn/core/output_reorder.h | 126 + thirdparty/oidn/core/transfer_function.cpp | 95 + thirdparty/oidn/core/transfer_function.h | 201 + thirdparty/oidn/core/upsample.h | 92 + thirdparty/oidn/core/weights_reorder.h | 99 + thirdparty/oidn/include/OpenImageDenoise/oidn.h | 214 + thirdparty/oidn/include/OpenImageDenoise/oidn.hpp | 468 ++ thirdparty/oidn/include/OpenImageDenoise/version.h | 23 + thirdparty/oidn/mkl-dnn/LICENSE | 214 + thirdparty/oidn/mkl-dnn/include/mkldnn.h | 1771 ++++++++ thirdparty/oidn/mkl-dnn/include/mkldnn.hpp | 2615 +++++++++++ thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h | 98 + thirdparty/oidn/mkl-dnn/include/mkldnn_types.h | 1415 ++++++ thirdparty/oidn/mkl-dnn/include/mkldnn_version.h | 32 + .../oidn/mkl-dnn/include/mkldnn_version.h.in | 32 + .../mkl-dnn/src/common/batch_normalization.cpp | 104 + .../mkl-dnn/src/common/batch_normalization_pd.hpp | 240 ++ thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp | 550 +++ thirdparty/oidn/mkl-dnn/src/common/concat.cpp | 86 + thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp | 211 + thirdparty/oidn/mkl-dnn/src/common/convolution.cpp | 200 + .../oidn/mkl-dnn/src/common/convolution_pd.cpp | 56 + .../oidn/mkl-dnn/src/common/convolution_pd.hpp | 348 ++ .../oidn/mkl-dnn/src/common/deconvolution.cpp | 188 + .../oidn/mkl-dnn/src/common/deconvolution_pd.hpp | 293 ++ thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp | 84 + thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp | 161 + thirdparty/oidn/mkl-dnn/src/common/engine.cpp | 75 + thirdparty/oidn/mkl-dnn/src/common/engine.hpp | 119 + .../oidn/mkl-dnn/src/common/inner_product.cpp | 106 + .../oidn/mkl-dnn/src/common/inner_product_pd.cpp | 56 + .../oidn/mkl-dnn/src/common/inner_product_pd.hpp | 321 ++ thirdparty/oidn/mkl-dnn/src/common/lrn.cpp | 91 + thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp | 170 + thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp | 280 ++ thirdparty/oidn/mkl-dnn/src/common/memory.cpp | 238 + thirdparty/oidn/mkl-dnn/src/common/memory.hpp | 63 + .../mkl-dnn/src/common/memory_desc_wrapper.cpp | 212 + .../mkl-dnn/src/common/memory_desc_wrapper.hpp | 400 ++ .../oidn/mkl-dnn/src/common/memory_tracking.hpp | 295 ++ .../oidn/mkl-dnn/src/common/mkldnn_debug.cpp | 131 + .../src/common/mkldnn_debug_autogenerated.cpp | 365 ++ .../oidn/mkl-dnn/src/common/mkldnn_thread.hpp | 115 + .../src/common/mkldnn_thread_parallel_nd.hpp | 277 ++ .../oidn/mkl-dnn/src/common/mkldnn_traits.hpp | 77 + thirdparty/oidn/mkl-dnn/src/common/nstl.hpp | 193 + thirdparty/oidn/mkl-dnn/src/common/pooling.cpp | 114 + thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp | 238 + thirdparty/oidn/mkl-dnn/src/common/primitive.cpp | 103 + thirdparty/oidn/mkl-dnn/src/common/primitive.hpp | 76 + .../oidn/mkl-dnn/src/common/primitive_attr.cpp | 290 ++ .../oidn/mkl-dnn/src/common/primitive_attr.hpp | 183 + .../oidn/mkl-dnn/src/common/primitive_desc.cpp | 78 + .../oidn/mkl-dnn/src/common/primitive_desc.hpp | 174 + .../mkl-dnn/src/common/primitive_exec_types.cpp | 90 + .../mkl-dnn/src/common/primitive_exec_types.hpp | 68 + .../oidn/mkl-dnn/src/common/primitive_iterator.cpp | 89 + .../oidn/mkl-dnn/src/common/primitive_iterator.hpp | 79 + thirdparty/oidn/mkl-dnn/src/common/query.cpp | 59 + thirdparty/oidn/mkl-dnn/src/common/reorder.cpp | 68 + thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp | 85 + thirdparty/oidn/mkl-dnn/src/common/rnn.cpp | 400 ++ thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp | 280 ++ thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp | 112 + thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp | 36 + thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp | 72 + thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp | 121 + thirdparty/oidn/mkl-dnn/src/common/softmax.cpp | 68 + thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp | 161 + thirdparty/oidn/mkl-dnn/src/common/stream.cpp | 46 + thirdparty/oidn/mkl-dnn/src/common/stream.hpp | 44 + thirdparty/oidn/mkl-dnn/src/common/sum.cpp | 79 + thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp | 143 + thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp | 200 + .../oidn/mkl-dnn/src/common/type_helpers.hpp | 348 ++ thirdparty/oidn/mkl-dnn/src/common/utils.cpp | 135 + thirdparty/oidn/mkl-dnn/src/common/utils.hpp | 370 ++ thirdparty/oidn/mkl-dnn/src/common/verbose.cpp | 665 +++ thirdparty/oidn/mkl-dnn/src/common/verbose.hpp | 62 + thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp | 46 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp | 112 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp | 60 + .../mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp | 40 + .../src/cpu/cpu_batch_normalization_utils.cpp | 140 + .../src/cpu/cpu_batch_normalization_utils.hpp | 43 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp | 51 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp | 41 + .../oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp | 74 + .../oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp | 46 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp | 45 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp | 324 ++ thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp | 70 + .../oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp | 84 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp | 151 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp | 42 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp | 277 ++ thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp | 89 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp | 40 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp | 83 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp | 544 +++ thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp | 334 ++ thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp | 262 ++ thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp | 48 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp | 41 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp | 45 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp | 48 + thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp | 39 + .../mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp | 372 ++ .../mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp | 72 + .../cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp | 2131 +++++++++ .../cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp | 36 + .../mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp | 2705 ++++++++++++ .../mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp | 37 + .../oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp | 346 ++ .../oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp | 36 + thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp | 280 ++ thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp | 58 + thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp | 86 + .../oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp | 206 + .../oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp | 28 + .../gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp | 1409 ++++++ .../gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp | 38 + .../s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp | 539 +++ .../s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp | 101 + .../gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp | 290 ++ .../jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp | 411 ++ .../jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp | 64 + .../s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp | 819 ++++ .../s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp | 2209 ++++++++++ .../s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp | 564 +++ .../s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp | 501 +++ .../jit_avx512_core_u8_copy_sum_an_kern.cpp | 1283 ++++++ .../jit_avx512_core_u8_copy_sum_at_kern.cpp | 3163 ++++++++++++++ .../jit_avx512_core_u8_copy_sum_bn_kern.cpp | 821 ++++ .../jit_avx512_core_u8_copy_sum_bt_kern.cpp | 647 +++ .../src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp | 116 + .../src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp | 38 + .../src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp | 180 + .../src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp | 37 + .../oidn/mkl-dnn/src/cpu/gemm_convolution.cpp | 307 ++ .../oidn/mkl-dnn/src/cpu/gemm_convolution.hpp | 250 ++ .../mkl-dnn/src/cpu/gemm_convolution_utils.cpp | 771 ++++ .../mkl-dnn/src/cpu/gemm_convolution_utils.hpp | 66 + .../oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp | 156 + .../oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp | 157 + .../mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp | 740 ++++ .../mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp | 266 ++ .../src/cpu/gemm_x8s8s32x_inner_product.cpp | 453 ++ .../src/cpu/gemm_x8s8s32x_inner_product.hpp | 166 + .../src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp | 674 +++ .../src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp | 110 + .../mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp | 545 +++ .../mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp | 344 ++ .../mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp | 1501 +++++++ .../mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp | 225 + .../oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp | 410 ++ .../oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp | 302 ++ .../src/cpu/jit_avx512_common_1x1_conv_kernel.cpp | 1255 ++++++ .../src/cpu/jit_avx512_common_1x1_conv_kernel.hpp | 108 + .../src/cpu/jit_avx512_common_1x1_convolution.cpp | 816 ++++ .../src/cpu/jit_avx512_common_1x1_convolution.hpp | 344 ++ .../src/cpu/jit_avx512_common_conv_kernel.cpp | 4539 ++++++++++++++++++++ .../src/cpu/jit_avx512_common_conv_kernel.hpp | 423 ++ .../jit_avx512_common_conv_winograd_kernel_f32.cpp | 1163 +++++ .../jit_avx512_common_conv_winograd_kernel_f32.hpp | 179 + .../src/cpu/jit_avx512_common_convolution.cpp | 1526 +++++++ .../src/cpu/jit_avx512_common_convolution.hpp | 302 ++ .../cpu/jit_avx512_common_convolution_winograd.cpp | 1215 ++++++ .../cpu/jit_avx512_common_convolution_winograd.hpp | 318 ++ .../oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp | 853 ++++ .../oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp | 96 + .../src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp | 1103 +++++ .../src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp | 144 + .../src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp | 1020 +++++ .../src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp | 386 ++ .../jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp | 2596 +++++++++++ .../jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp | 291 ++ .../jit_avx512_core_u8s8s32x_wino_convolution.cpp | 1284 ++++++ .../jit_avx512_core_u8s8s32x_wino_convolution.hpp | 128 + .../jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp | 820 ++++ .../jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp | 131 + .../jit_avx512_core_x8s8s32x_1x1_convolution.cpp | 292 ++ .../jit_avx512_core_x8s8s32x_1x1_convolution.hpp | 159 + .../jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp | 140 + .../cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp | 1182 +++++ .../cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp | 239 ++ .../cpu/jit_avx512_core_x8s8s32x_convolution.cpp | 423 ++ .../cpu/jit_avx512_core_x8s8s32x_convolution.hpp | 115 + .../cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp | 1034 +++++ .../cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp | 237 + thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp | 773 ++++ .../oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp | 481 +++ .../src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp | 677 +++ .../src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp | 104 + .../mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp | 134 + .../mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp | 96 + .../mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp | 497 +++ .../mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp | 93 + .../oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp | 136 + .../oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp | 103 + .../mkl-dnn/src/cpu/jit_transpose_src_utils.cpp | 1192 +++++ .../mkl-dnn/src/cpu/jit_transpose_src_utils.hpp | 145 + .../mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp | 327 ++ .../src/cpu/jit_uni_batch_normalization.cpp | 1407 ++++++ .../src/cpu/jit_uni_batch_normalization.hpp | 100 + .../mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp | 1302 ++++++ .../mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp | 253 ++ .../mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp | 427 ++ .../mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp | 266 ++ .../oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp | 1142 +++++ .../oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp | 193 + .../oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp | 949 ++++ .../oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp | 89 + thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp | 305 ++ thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp | 103 + .../mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp | 1487 +++++++ .../mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp | 183 + .../mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp | 699 +++ .../mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp | 192 + .../oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp | 264 ++ .../oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp | 182 + .../oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp | 1006 +++++ .../oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp | 127 + .../oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp | 313 ++ .../oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp | 115 + .../oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp | 32 + .../src/cpu/jit_utils/jitprofiling/LICENSE.BSD | 27 + .../src/cpu/jit_utils/jitprofiling/README.md | 1 + .../cpu/jit_utils/jitprofiling/ittnotify_config.h | 595 +++ .../cpu/jit_utils/jitprofiling/ittnotify_types.h | 94 + .../src/cpu/jit_utils/jitprofiling/jitprofiling.c | 293 ++ .../src/cpu/jit_utils/jitprofiling/jitprofiling.h | 673 +++ thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp | 317 ++ thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp | 147 + .../mkl-dnn/src/cpu/ncsp_batch_normalization.cpp | 382 ++ .../mkl-dnn/src/cpu/ncsp_batch_normalization.hpp | 160 + thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp | 392 ++ thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp | 210 + .../mkl-dnn/src/cpu/nspc_batch_normalization.cpp | 288 ++ .../mkl-dnn/src/cpu/nspc_batch_normalization.hpp | 169 + .../mkl-dnn/src/cpu/ref_batch_normalization.cpp | 265 ++ .../mkl-dnn/src/cpu/ref_batch_normalization.hpp | 127 + thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp | 97 + .../oidn/mkl-dnn/src/cpu/ref_convolution.cpp | 395 ++ .../oidn/mkl-dnn/src/cpu/ref_convolution.hpp | 194 + .../oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp | 199 + .../oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp | 502 +++ thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp | 297 ++ thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp | 168 + .../oidn/mkl-dnn/src/cpu/ref_inner_product.cpp | 285 ++ .../oidn/mkl-dnn/src/cpu/ref_inner_product.hpp | 159 + thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp | 252 ++ thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp | 136 + thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp | 381 ++ thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp | 119 + thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp | 153 + thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp | 111 + thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp | 264 ++ thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp | 186 + thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp | 101 + .../oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp | 90 + thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp | 180 + .../oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp | 170 + thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp | 143 + thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp | 113 + thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp | 191 + .../mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp | 401 ++ thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp | 788 ++++ thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp | 328 ++ .../oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp | 380 ++ thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp | 426 ++ thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp | 225 + thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp | 126 + thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp | 155 + thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp | 98 + thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp | 1022 +++++ thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp | 91 + thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp | 74 + thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp | 376 ++ thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT | 47 + thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h | 2658 ++++++++++++ .../oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h | 303 ++ .../oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h | 2017 +++++++++ thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h | 772 ++++ thirdparty/oidn/weights/rtlightmap_hdr.tza | Bin 0 -> 5660131 bytes 311 files changed, 113014 insertions(+) create mode 100644 thirdparty/oidn/.gitignore create mode 100644 thirdparty/oidn/common/barrier.h create mode 100644 thirdparty/oidn/common/exception.h create mode 100644 thirdparty/oidn/common/platform.cpp create mode 100644 thirdparty/oidn/common/platform.h create mode 100644 thirdparty/oidn/common/ref.h create mode 100644 thirdparty/oidn/common/tensor.cpp create mode 100644 thirdparty/oidn/common/tensor.h create mode 100644 thirdparty/oidn/common/thread.cpp create mode 100644 thirdparty/oidn/common/thread.h create mode 100644 thirdparty/oidn/common/timer.h create mode 100644 thirdparty/oidn/core/api.cpp create mode 100644 thirdparty/oidn/core/autoencoder.cpp create mode 100644 thirdparty/oidn/core/autoencoder.h create mode 100644 thirdparty/oidn/core/buffer.h create mode 100644 thirdparty/oidn/core/common.h create mode 100644 thirdparty/oidn/core/device.cpp create mode 100644 thirdparty/oidn/core/device.h create mode 100644 thirdparty/oidn/core/filter.cpp create mode 100644 thirdparty/oidn/core/filter.h create mode 100644 thirdparty/oidn/core/image.h create mode 100644 thirdparty/oidn/core/input_reorder.h create mode 100644 thirdparty/oidn/core/math.h create mode 100644 thirdparty/oidn/core/network.cpp create mode 100644 thirdparty/oidn/core/network.h create mode 100644 thirdparty/oidn/core/node.h create mode 100644 thirdparty/oidn/core/output_reorder.h create mode 100644 thirdparty/oidn/core/transfer_function.cpp create mode 100644 thirdparty/oidn/core/transfer_function.h create mode 100644 thirdparty/oidn/core/upsample.h create mode 100644 thirdparty/oidn/core/weights_reorder.h create mode 100644 thirdparty/oidn/include/OpenImageDenoise/oidn.h create mode 100644 thirdparty/oidn/include/OpenImageDenoise/oidn.hpp create mode 100644 thirdparty/oidn/include/OpenImageDenoise/version.h create mode 100644 thirdparty/oidn/mkl-dnn/LICENSE create mode 100644 thirdparty/oidn/mkl-dnn/include/mkldnn.h create mode 100644 thirdparty/oidn/mkl-dnn/include/mkldnn.hpp create mode 100644 thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h create mode 100644 thirdparty/oidn/mkl-dnn/include/mkldnn_types.h create mode 100644 thirdparty/oidn/mkl-dnn/include/mkldnn_version.h create mode 100644 thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in create mode 100644 thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/concat.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/engine.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/engine.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/lrn.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/memory.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/memory.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/nstl.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/pooling.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/query.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/reorder.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/rnn.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/softmax.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/stream.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/stream.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/sum.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/utils.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/utils.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/verbose.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/verbose.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h create mode 100644 thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h create mode 100644 thirdparty/oidn/weights/rtlightmap_hdr.tza (limited to 'thirdparty/oidn') diff --git a/thirdparty/oidn/.gitignore b/thirdparty/oidn/.gitignore new file mode 100644 index 0000000000..6be206fc29 --- /dev/null +++ b/thirdparty/oidn/.gitignore @@ -0,0 +1 @@ +weights/rtlightmap_hdr.cpp diff --git a/thirdparty/oidn/common/barrier.h b/thirdparty/oidn/common/barrier.h new file mode 100644 index 0000000000..b20f670053 --- /dev/null +++ b/thirdparty/oidn/common/barrier.h @@ -0,0 +1,52 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include +#include + +namespace oidn { + + class Barrier + { + private: + std::mutex m; + std::condition_variable cv; + volatile int count; + + public: + Barrier(int count) : count(count) {} + + void wait() + { + std::unique_lock lk(m); + count--; + + if (count == 0) + { + lk.unlock(); + cv.notify_all(); + } + else + { + cv.wait(lk, [&]{ return count == 0; }); + } + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/common/exception.h b/thirdparty/oidn/common/exception.h new file mode 100644 index 0000000000..18069c6a7d --- /dev/null +++ b/thirdparty/oidn/common/exception.h @@ -0,0 +1,45 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include +#include "platform.h" + +namespace oidn { + + class Exception : public std::exception + { + private: + Error error; + const char* message; + + public: + Exception(Error error, const char* message) + : error(error), message(message) {} + + Error code() const noexcept + { + return error; + } + + const char* what() const noexcept override + { + return message; + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/common/platform.cpp b/thirdparty/oidn/common/platform.cpp new file mode 100644 index 0000000000..59a14ff47c --- /dev/null +++ b/thirdparty/oidn/common/platform.cpp @@ -0,0 +1,114 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "platform.h" + +namespace oidn { + + // ---------------------------------------------------------------------------- + // Common functions + // ---------------------------------------------------------------------------- + + void* alignedMalloc(size_t size, size_t alignment) + { + if (size == 0) + return nullptr; + + assert((alignment & (alignment-1)) == 0); + void* ptr = _mm_malloc(size, alignment); + + if (ptr == nullptr) + throw std::bad_alloc(); + + return ptr; + } + + void alignedFree(void* ptr) + { + if (ptr) + _mm_free(ptr); + } + + // ---------------------------------------------------------------------------- + // System information + // ---------------------------------------------------------------------------- + + std::string getPlatformName() + { + std::string name; + + #if defined(__linux__) + name = "Linux"; + #elif defined(__FreeBSD__) + name = "FreeBSD"; + #elif defined(__CYGWIN__) + name = "Cygwin"; + #elif defined(_WIN32) + name = "Windows"; + #elif defined(__APPLE__) + name = "macOS"; + #elif defined(__unix__) + name = "Unix"; + #else + return "Unknown"; + #endif + + #if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__) + name += " (64-bit)"; + #else + name += " (32-bit)"; + #endif + + return name; + } + + std::string getCompilerName() + { + #if defined(__INTEL_COMPILER) + int mayor = __INTEL_COMPILER / 100 % 100; + int minor = __INTEL_COMPILER % 100; + std::string version = "Intel Compiler "; + version += toString(mayor); + version += "." + toString(minor); + #if defined(__INTEL_COMPILER_UPDATE) + version += "." + toString(__INTEL_COMPILER_UPDATE); + #endif + return version; + #elif defined(__clang__) + return "Clang " __clang_version__; + #elif defined(__GNUC__) + return "GCC " __VERSION__; + #elif defined(_MSC_VER) + std::string version = toString(_MSC_FULL_VER); + version.insert(4, "."); + version.insert(9, "."); + version.insert(2, "."); + return "Visual C++ Compiler " + version; + #else + return "Unknown"; + #endif + } + + std::string getBuildName() + { + #if defined(NDEBUG) + return "Release"; + #else + return "Debug"; + #endif + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/platform.h b/thirdparty/oidn/common/platform.h new file mode 100644 index 0000000000..205ac8981d --- /dev/null +++ b/thirdparty/oidn/common/platform.h @@ -0,0 +1,131 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #define NOMINMAX + #include +#elif defined(__APPLE__) + #include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "include/OpenImageDenoise/oidn.hpp" + +namespace oidn { + + // ---------------------------------------------------------------------------- + // Macros + // ---------------------------------------------------------------------------- + + #if defined(_WIN32) + // Windows + #if !defined(__noinline) + #define __noinline __declspec(noinline) + #endif + #else + // Unix + #if !defined(__forceinline) + #define __forceinline inline __attribute__((always_inline)) + #endif + #if !defined(__noinline) + #define __noinline __attribute__((noinline)) + #endif + #endif + + #ifndef UNUSED + #define UNUSED(x) ((void)x) + #endif + #ifndef MAYBE_UNUSED + #define MAYBE_UNUSED(x) UNUSED(x) + #endif + + // ---------------------------------------------------------------------------- + // Error handling and debugging + // ---------------------------------------------------------------------------- + + struct Verbose + { + int verbose; + + Verbose(int v = 0) : verbose(v) {} + __forceinline bool isVerbose(int v = 1) const { return v <= verbose; } + }; + + #define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; } + #define OIDN_FATAL(message) throw std::runtime_error(message); + + // ---------------------------------------------------------------------------- + // Common functions + // ---------------------------------------------------------------------------- + + using std::min; + using std::max; + + template + __forceinline T clamp(const T& value, const T& minValue, const T& maxValue) + { + return min(max(value, minValue), maxValue); + } + + void* alignedMalloc(size_t size, size_t alignment); + void alignedFree(void* ptr); + + template + inline std::string toString(const T& a) + { + std::stringstream sm; + sm << a; + return sm.str(); + } + +#if defined(__APPLE__) + template + bool getSysctl(const char* name, T& value) + { + int64_t result = 0; + size_t size = sizeof(result); + + if (sysctlbyname(name, &result, &size, nullptr, 0) != 0) + return false; + + value = T(result); + return true; + } +#endif + + // ---------------------------------------------------------------------------- + // System information + // ---------------------------------------------------------------------------- + + std::string getPlatformName(); + std::string getCompilerName(); + std::string getBuildName(); + +} // namespace oidn diff --git a/thirdparty/oidn/common/ref.h b/thirdparty/oidn/common/ref.h new file mode 100644 index 0000000000..de44603af2 --- /dev/null +++ b/thirdparty/oidn/common/ref.h @@ -0,0 +1,163 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" + +namespace oidn { + + class RefCount + { + private: + std::atomic count; + + public: + __forceinline RefCount(int count = 0) noexcept : count(count) {} + + __forceinline size_t incRef() noexcept + { + return count.fetch_add(1) + 1; + } + + __forceinline size_t decRef() + { + const size_t newCount = decRefKeep(); + if (newCount == 0) + destroy(); + return newCount; + } + + __forceinline size_t decRefKeep() noexcept + { + return count.fetch_add(-1) - 1; + } + + __forceinline void destroy() + { + delete this; + } + + protected: + // Disable copying + RefCount(const RefCount&) = delete; + RefCount& operator =(const RefCount&) = delete; + + virtual ~RefCount() noexcept = default; + }; + + template + class Ref + { + private: + T* ptr; + + public: + __forceinline Ref() noexcept : ptr(nullptr) {} + __forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {} + __forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); } + __forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; } + __forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } + + template + __forceinline Ref(const Ref& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); } + + template + __forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } + + __forceinline ~Ref() { if (ptr) ptr->decRef(); } + + __forceinline Ref& operator =(const Ref& other) + { + if (other.ptr) + other.ptr->incRef(); + if (ptr) + ptr->decRef(); + ptr = other.ptr; + return *this; + } + + __forceinline Ref& operator =(Ref&& other) + { + if (ptr) + ptr->decRef(); + ptr = other.ptr; + other.ptr = nullptr; + return *this; + } + + __forceinline Ref& operator =(T* other) + { + if (other) + other->incRef(); + if (ptr) + ptr->decRef(); + ptr = other; + return *this; + } + + __forceinline Ref& operator =(std::nullptr_t) + { + if (ptr) + ptr->decRef(); + ptr = nullptr; + return *this; + } + + __forceinline operator bool() const noexcept { return ptr != nullptr; } + + __forceinline T& operator *() const noexcept { return *ptr; } + __forceinline T* operator ->() const noexcept { return ptr; } + + __forceinline T* get() const noexcept { return ptr; } + + __forceinline T* detach() noexcept + { + T* res = ptr; + ptr = nullptr; + return res; + } + }; + + template __forceinline bool operator < (const Ref& a, const Ref& b) noexcept { return a.ptr < b.ptr; } + + template __forceinline bool operator ==(const Ref& a, std::nullptr_t) noexcept { return a.ptr == nullptr; } + template __forceinline bool operator ==(std::nullptr_t, const Ref& b) noexcept { return nullptr == b.ptr; } + template __forceinline bool operator ==(const Ref& a, const Ref& b) noexcept { return a.ptr == b.ptr; } + + template __forceinline bool operator !=(const Ref& a, std::nullptr_t) noexcept { return a.ptr != nullptr; } + template __forceinline bool operator !=(std::nullptr_t, const Ref& b) noexcept { return nullptr != b.ptr; } + template __forceinline bool operator !=(const Ref& a, const Ref& b) noexcept { return a.ptr != b.ptr; } + + template + __forceinline Ref makeRef(Args&&... args) + { + return Ref(new T(std::forward(args)...)); + } + + template + __forceinline Ref staticRefCast(const Ref& a) + { + return Ref(static_cast(a.get())); + } + + template + __forceinline Ref dynamicRefCast(const Ref& a) + { + return Ref(dynamic_cast(a.get())); + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.cpp b/thirdparty/oidn/common/tensor.cpp new file mode 100644 index 0000000000..0249f2e141 --- /dev/null +++ b/thirdparty/oidn/common/tensor.cpp @@ -0,0 +1,83 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "exception.h" +#include "tensor.h" + +namespace oidn { + + std::map parseTensors(void* buffer) + { + char* input = (char*)buffer; + + // Parse the magic value + const int magic = *(unsigned short*)input; + if (magic != 0x41D7) + throw Exception(Error::InvalidOperation, "invalid tensor archive"); + input += sizeof(unsigned short); + + // Parse the version + const int majorVersion = *(unsigned char*)input++; + const int minorVersion = *(unsigned char*)input++; + UNUSED(minorVersion); + if (majorVersion > 1) + throw Exception(Error::InvalidOperation, "unsupported tensor archive version"); + + // Parse the number of tensors + const int numTensors = *(int*)input; + input += sizeof(int); + + // Parse the tensors + std::map tensorMap; + for (int i = 0; i < numTensors; ++i) + { + Tensor tensor; + + // Parse the name + const int nameLen = *(unsigned char*)input++; + std::string name(input, nameLen); + input += nameLen; + + // Parse the number of dimensions + const int ndims = *(unsigned char*)input++; + + // Parse the shape of the tensor + tensor.dims.resize(ndims); + for (int i = 0; i < ndims; ++i) + tensor.dims[i] = ((int*)input)[i]; + input += ndims * sizeof(int); + + // Parse the format of the tensor + tensor.format = std::string(input, input + ndims); + input += ndims; + + // Parse the data type of the tensor + const char type = *(unsigned char*)input++; + if (type != 'f') // only float32 is supported + throw Exception(Error::InvalidOperation, "unsupported tensor data type"); + + // Skip the data + tensor.data = (float*)input; + input += tensor.size() * sizeof(float); + + // Add the tensor to the map + tensorMap.emplace(name, std::move(tensor)); + } + + return tensorMap; + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.h b/thirdparty/oidn/common/tensor.h new file mode 100644 index 0000000000..48e7d1123d --- /dev/null +++ b/thirdparty/oidn/common/tensor.h @@ -0,0 +1,66 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include +#include + +namespace oidn { + + template + using shared_vector = std::shared_ptr>; + + // Generic tensor + struct Tensor + { + float* data; + std::vector dims; + std::string format; + shared_vector buffer; // optional, only for reference counting + + __forceinline Tensor() : data(nullptr) {} + + __forceinline Tensor(const std::vector& dims, const std::string& format) + : dims(dims), + format(format) + { + buffer = std::make_shared>(size() * sizeof(float)); + data = (float*)buffer->data(); + } + + __forceinline operator bool() const { return data != nullptr; } + + __forceinline int ndims() const { return (int)dims.size(); } + + // Returns the number of values + __forceinline size_t size() const + { + size_t size = 1; + for (int i = 0; i < ndims(); ++i) + size *= dims[i]; + return size; + } + + __forceinline float& operator [](size_t i) { return data[i]; } + __forceinline const float& operator [](size_t i) const { return data[i]; } + }; + + // Parses tensors from a buffer + std::map parseTensors(void* buffer); + +} // namespace oidn diff --git a/thirdparty/oidn/common/thread.cpp b/thirdparty/oidn/common/thread.cpp new file mode 100644 index 0000000000..48c489c57b --- /dev/null +++ b/thirdparty/oidn/common/thread.cpp @@ -0,0 +1,297 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#if defined(_MSC_VER) + #pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned +#endif + +#if defined(__APPLE__) + #include + #include +#endif + +#include "thread.h" +#include + +namespace oidn { + +#if defined(_WIN32) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Windows + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + HMODULE hLib = GetModuleHandle(TEXT("kernel32")); + pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx"); + pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity"); + + if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity) + { + // Get logical processor information + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr; + DWORD bufferSize = 0; + + // First call the function with an empty buffer to get the required buffer size + BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); + if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER) + { + OIDN_WARNING("GetLogicalProcessorInformationEx failed"); + return; + } + + // Allocate the buffer + buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize); + if (!buffer) + { + OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed"); + return; + } + + // Call again the function but now with the properly sized buffer + result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); + if (!result) + { + OIDN_WARNING("GetLogicalProcessorInformationEx failed"); + free(buffer); + return; + } + + // Iterate over the logical processor information structures + // There should be one structure for each physical core + char* ptr = (char*)buffer; + while (ptr < (char*)buffer + bufferSize) + { + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr; + if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0) + { + // Iterate over the groups + int numThreads = 0; + for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group) + { + GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group]; + while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore)) + { + // Extract the next set bit/thread from the mask + GROUP_AFFINITY threadAffinity = coreAffinity; + threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask; + + // Push the affinity for this thread + affinities.push_back(threadAffinity); + oldAffinities.push_back(threadAffinity); + numThreads++; + + // Remove this bit/thread from the mask + coreAffinity.Mask ^= threadAffinity.Mask; + } + } + } + + // Next structure + ptr += item->Size; + } + + // Free the buffer + free(buffer); + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + // Save the current affinity and set the new one + const HANDLE thread = GetCurrentThread(); + if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex])) + OIDN_WARNING("SetThreadGroupAffinity failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + // Restore the original affinity + const HANDLE thread = GetCurrentThread(); + if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr)) + OIDN_WARNING("SetThreadGroupAffinity failed"); + } + +#elif defined(__linux__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Linux + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + std::vector threadIds; + + // Parse the thread/CPU topology + for (int cpuId = 0; ; cpuId++) + { + std::fstream fs; + std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list"); + fs.open(cpu.c_str(), std::fstream::in); + if (fs.fail()) break; + + int i; + int j = 0; + while ((j < numThreadsPerCore) && (fs >> i)) + { + if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; })) + threadIds.push_back(i); + + if (fs.peek() == ',') + fs.ignore(); + j++; + } + + fs.close(); + } + + #if 0 + for (size_t i = 0; i < thread_ids.size(); ++i) + std::cout << "thread " << i << " -> " << thread_ids[i] << std::endl; + #endif + + // Create the affinity structures + affinities.resize(threadIds.size()); + oldAffinities.resize(threadIds.size()); + + for (size_t i = 0; i < threadIds.size(); ++i) + { + cpu_set_t affinity; + CPU_ZERO(&affinity); + CPU_SET(threadIds[i], &affinity); + + affinities[i] = affinity; + oldAffinities[i] = affinity; + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const pthread_t thread = pthread_self(); + + // Save the current affinity + if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) + { + OIDN_WARNING("pthread_getaffinity_np failed"); + oldAffinities[threadIndex] = affinities[threadIndex]; + return; + } + + // Set the new affinity + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0) + OIDN_WARNING("pthread_setaffinity_np failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const pthread_t thread = pthread_self(); + + // Restore the original affinity + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) + OIDN_WARNING("pthread_setaffinity_np failed"); + } + +#elif defined(__APPLE__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - macOS + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + // Query the thread/CPU topology + int numPhysicalCpus; + int numLogicalCpus; + + if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus)) + { + OIDN_WARNING("sysctlbyname failed"); + return; + } + + if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1)) + return; // this shouldn't happen + const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus; + + // Create the affinity structures + // macOS doesn't support binding a thread to a specific core, but we can at least group threads which + // should be on the same core together + for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1! + { + thread_affinity_policy affinity; + affinity.affinity_tag = core; + + for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread) + { + affinities.push_back(affinity); + oldAffinities.push_back(affinity); + } + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const auto thread = mach_thread_self(); + + // Save the current affinity + mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT; + boolean_t getDefault = FALSE; + if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS) + { + OIDN_WARNING("thread_policy_get failed"); + oldAffinities[threadIndex] = affinities[threadIndex]; + return; + } + + // Set the new affinity + if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + OIDN_WARNING("thread_policy_set failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const auto thread = mach_thread_self(); + + // Restore the original affinity + if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + OIDN_WARNING("thread_policy_set failed"); + } + +#endif + +} // namespace oidn diff --git a/thirdparty/oidn/common/thread.h b/thirdparty/oidn/common/thread.h new file mode 100644 index 0000000000..2c731367da --- /dev/null +++ b/thirdparty/oidn/common/thread.h @@ -0,0 +1,202 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" + +#if !defined(_WIN32) + #include + #include + #if defined(__APPLE__) + #include + #endif +#endif + +#include +#include + +namespace oidn { + + // -------------------------------------------------------------------------- + // ThreadLocal + // -------------------------------------------------------------------------- + + // Wrapper which makes any variable thread-local + template + class ThreadLocal : public Verbose + { + private: + #if defined(_WIN32) + DWORD key; + #else + pthread_key_t key; + #endif + + std::vector instances; + std::mutex mutex; + + public: + ThreadLocal(int verbose = 0) + : Verbose(verbose) + { + #if defined(_WIN32) + key = TlsAlloc(); + if (key == TLS_OUT_OF_INDEXES) + OIDN_FATAL("TlsAlloc failed"); + #else + if (pthread_key_create(&key, nullptr) != 0) + OIDN_FATAL("pthread_key_create failed"); + #endif + } + + ~ThreadLocal() + { + std::lock_guard lock(mutex); + for (T* ptr : instances) + delete ptr; + + #if defined(_WIN32) + if (!TlsFree(key)) + OIDN_WARNING("TlsFree failed"); + #else + if (pthread_key_delete(key) != 0) + OIDN_WARNING("pthread_key_delete failed"); + #endif + } + + T& get() + { + #if defined(_WIN32) + T* ptr = (T*)TlsGetValue(key); + #else + T* ptr = (T*)pthread_getspecific(key); + #endif + + if (ptr) + return *ptr; + + ptr = new T; + std::lock_guard lock(mutex); + instances.push_back(ptr); + + #if defined(_WIN32) + if (!TlsSetValue(key, ptr)) + OIDN_FATAL("TlsSetValue failed"); + #else + if (pthread_setspecific(key, ptr) != 0) + OIDN_FATAL("pthread_setspecific failed"); + #endif + + return *ptr; + } + }; + +#if defined(_WIN32) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Windows + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP, + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, + PDWORD); + + typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE, + CONST GROUP_AFFINITY*, + PGROUP_AFFINITY); + + GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr; + SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr; + + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#elif defined(__linux__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Linux + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#elif defined(__APPLE__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - macOS + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + std::vector affinities; // thread affinities + std::vector oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#endif + +} // namespace oidn diff --git a/thirdparty/oidn/common/timer.h b/thirdparty/oidn/common/timer.h new file mode 100644 index 0000000000..62aaaa1c33 --- /dev/null +++ b/thirdparty/oidn/common/timer.h @@ -0,0 +1,49 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include + +namespace oidn { + + class Timer + { + private: + using clock = std::chrono::high_resolution_clock; + + std::chrono::time_point start; + + public: + Timer() + { + reset(); + } + + void reset() + { + start = clock::now(); + } + + double query() const + { + auto end = clock::now(); + return std::chrono::duration_cast>(end - start).count(); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/api.cpp b/thirdparty/oidn/core/api.cpp new file mode 100644 index 0000000000..7353fe4e25 --- /dev/null +++ b/thirdparty/oidn/core/api.cpp @@ -0,0 +1,408 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#ifdef _WIN32 +# define OIDN_API extern "C" __declspec(dllexport) +#else +# define OIDN_API extern "C" __attribute__ ((visibility ("default"))) +#endif + +// Locks the device that owns the specified object +// Use *only* inside OIDN_TRY/CATCH! +#define OIDN_LOCK(obj) \ + std::lock_guard lock(obj->getDevice()->getMutex()); + +// Try/catch for converting exceptions to errors +#define OIDN_TRY \ + try { + +#define OIDN_CATCH(obj) \ + } catch (Exception& e) { \ + Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \ + } catch (std::bad_alloc&) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ + } catch (mkldnn::error& e) { \ + if (e.status == mkldnn_out_of_memory) \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ + else \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \ + } catch (std::exception& e) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \ + } catch (...) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \ + } + +#include "device.h" +#include "filter.h" +#include + +namespace oidn { + + namespace + { + __forceinline void checkHandle(void* handle) + { + if (handle == nullptr) + throw Exception(Error::InvalidArgument, "invalid handle"); + } + + template + __forceinline void retainObject(T* obj) + { + if (obj) + { + obj->incRef(); + } + else + { + OIDN_TRY + checkHandle(obj); + OIDN_CATCH(obj) + } + } + + template + __forceinline void releaseObject(T* obj) + { + if (obj == nullptr || obj->decRefKeep() == 0) + { + OIDN_TRY + checkHandle(obj); + OIDN_LOCK(obj); + obj->destroy(); + OIDN_CATCH(obj) + } + } + + template<> + __forceinline void releaseObject(Device* obj) + { + if (obj == nullptr || obj->decRefKeep() == 0) + { + OIDN_TRY + checkHandle(obj); + // Do NOT lock the device because it owns the mutex + obj->destroy(); + OIDN_CATCH(obj) + } + } + } + + OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type) + { + Ref device = nullptr; + OIDN_TRY + if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT) + device = makeRef(); + else + throw Exception(Error::InvalidArgument, "invalid device type"); + OIDN_CATCH(device) + return (OIDNDevice)device.detach(); + } + + OIDN_API void oidnRetainDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + retainObject(device); + } + + OIDN_API void oidnReleaseDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + releaseObject(device); + } + + OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->set1i(name, value); + OIDN_CATCH(device) + } + + OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->set1i(name, value); + OIDN_CATCH(device) + } + + OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + return device->get1i(name); + OIDN_CATCH(device) + return false; + } + + OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + return device->get1i(name); + OIDN_CATCH(device) + return 0; + } + + OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->setErrorFunction((ErrorFunction)func, userPtr); + OIDN_CATCH(device) + } + + OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage) + { + Device* device = (Device*)hDevice; + OIDN_TRY + return (OIDNError)Device::getError(device, outMessage); + OIDN_CATCH(device) + if (outMessage) *outMessage = ""; + return OIDN_ERROR_UNKNOWN; + } + + OIDN_API void oidnCommitDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->commit(); + OIDN_CATCH(device) + } + + OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref buffer = device->newBuffer(byteSize); + return (OIDNBuffer)buffer.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref buffer = device->newBuffer(ptr, byteSize); + return (OIDNBuffer)buffer.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer) + { + Buffer* buffer = (Buffer*)hBuffer; + retainObject(buffer); + } + + OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer) + { + Buffer* buffer = (Buffer*)hBuffer; + releaseObject(buffer); + } + + OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize) + { + Buffer* buffer = (Buffer*)hBuffer; + OIDN_TRY + checkHandle(hBuffer); + OIDN_LOCK(buffer); + return buffer->map(byteOffset, byteSize); + OIDN_CATCH(buffer) + return nullptr; + } + + OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr) + { + Buffer* buffer = (Buffer*)hBuffer; + OIDN_TRY + checkHandle(hBuffer); + OIDN_LOCK(buffer); + return buffer->unmap(mappedPtr); + OIDN_CATCH(buffer) + } + + OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref filter = device->newFilter(type); + return (OIDNFilter)filter.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API void oidnRetainFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + retainObject(filter); + } + + OIDN_API void oidnReleaseFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + releaseObject(filter); + } + + OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name, + OIDNBuffer hBuffer, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + checkHandle(hBuffer); + OIDN_LOCK(filter); + Ref buffer = (Buffer*)hBuffer; + if (buffer->getDevice() != filter->getDevice()) + throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices"); + Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); + filter->setImage(name, data); + OIDN_CATCH(filter) + } + + OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name, + void* ptr, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); + filter->setImage(name, data); + OIDN_CATCH(filter) + } + + OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1i(name, int(value)); + OIDN_CATCH(filter) + } + + OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1i(name); + OIDN_CATCH(filter) + return false; + } + + OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1i(name, value); + OIDN_CATCH(filter) + } + + OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1i(name); + OIDN_CATCH(filter) + return 0; + } + + OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1f(name, value); + OIDN_CATCH(filter) + } + + OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1f(name); + OIDN_CATCH(filter) + return 0; + } + + OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->setProgressMonitorFunction(func, userPtr); + OIDN_CATCH(filter) + } + + OIDN_API void oidnCommitFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->commit(); + OIDN_CATCH(filter) + } + + OIDN_API void oidnExecuteFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->execute(); + OIDN_CATCH(filter) + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.cpp b/thirdparty/oidn/core/autoencoder.cpp new file mode 100644 index 0000000000..8ae2421fa6 --- /dev/null +++ b/thirdparty/oidn/core/autoencoder.cpp @@ -0,0 +1,519 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "autoencoder.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // AutoencoderFilter + // -------------------------------------------------------------------------- + + AutoencoderFilter::AutoencoderFilter(const Ref& device) + : Filter(device) + { + } + + void AutoencoderFilter::setImage(const std::string& name, const Image& data) + { + if (name == "color") + color = data; + else if (name == "albedo") + albedo = data; + else if (name == "normal") + normal = data; + else if (name == "output") + output = data; + + dirty = true; + } + + void AutoencoderFilter::set1i(const std::string& name, int value) + { + if (name == "hdr") + hdr = value; + else if (name == "srgb") + srgb = value; + else if (name == "maxMemoryMB") + maxMemoryMB = value; + + dirty = true; + } + + int AutoencoderFilter::get1i(const std::string& name) + { + if (name == "hdr") + return hdr; + else if (name == "srgb") + return srgb; + else if (name == "maxMemoryMB") + return maxMemoryMB; + else if (name == "alignment") + return alignment; + else if (name == "overlap") + return overlap; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void AutoencoderFilter::set1f(const std::string& name, float value) + { + if (name == "hdrScale") + hdrScale = value; + + dirty = true; + } + + float AutoencoderFilter::get1f(const std::string& name) + { + if (name == "hdrScale") + return hdrScale; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void AutoencoderFilter::commit() + { + if (!dirty) + return; + + { + if (mayiuse(avx512_common)) + net = buildNet<16>(); + else + net = buildNet<8>(); + } + + dirty = false; + } + + void AutoencoderFilter::execute() + { + if (dirty) + throw Exception(Error::InvalidOperation, "changes to the filter are not committed"); + + if (!net) + return; + + { + Progress progress; + progress.func = progressFunc; + progress.userPtr = progressUserPtr; + progress.taskCount = tileCountH * tileCountW; + + // Iterate over the tiles + int tileIndex = 0; + + for (int i = 0; i < tileCountH; ++i) + { + const int h = i * (tileH - 2*overlap); // input tile position (including overlap) + const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top + const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom + const int tileH1 = min(H - h, tileH); // input tile size (including overlap) + const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size + const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer + + for (int j = 0; j < tileCountW; ++j) + { + const int w = j * (tileW - 2*overlap); // input tile position (including overlap) + const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left + const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right + const int tileW1 = min(W - w, tileW); // input tile size (including overlap) + const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size + const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer + + // Set the input tile + inputReorder->setTile(h, w, + alignOffsetH, alignOffsetW, + tileH1, tileW1); + + // Set the output tile + outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW, + h + overlapBeginH, w + overlapBeginW, + tileH2, tileW2); + + //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2); + + // Denoise the tile + net->execute(progress, tileIndex); + + // Next tile + tileIndex++; + } + } + } + } + + void AutoencoderFilter::computeTileSize() + { + const int minTileSize = 3*overlap; + const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8; + const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel; + + tileCountH = 1; + tileCountW = 1; + tileH = roundUp(H, alignment); + tileW = roundUp(W, alignment); + + // Divide the image into tiles until the tile size gets below the threshold + while (int64_t(tileH) * tileW > maxTilePixels) + { + if (tileH > minTileSize && tileH > tileW) + { + tileCountH++; + tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize); + } + else if (tileW > minTileSize) + { + tileCountW++; + tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize); + } + else + break; + } + + // Compute the final number of tiles + tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1; + tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1; + + if (device->isVerbose(2)) + { + std::cout << "Tile size : " << tileW << "x" << tileH << std::endl; + std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl; + } + } + + template + std::shared_ptr AutoencoderFilter::buildNet() + { + H = color.height; + W = color.width; + + // Configure the network + int inputC; + void* weightPtr; + + if (srgb && hdr) + throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time"); + + if (color && !albedo && !normal && weightData.hdr) + { + inputC = 3; + weightPtr = hdr ? weightData.hdr : weightData.ldr; + } + else if (color && albedo && !normal && weightData.hdr_alb) + { + inputC = 6; + weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb; + } + else if (color && albedo && normal && weightData.hdr_alb_nrm) + { + inputC = 9; + weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm; + } + else + { + throw Exception(Error::InvalidOperation, "unsupported combination of input features"); + } + + if (!output) + throw Exception(Error::InvalidOperation, "output image not specified"); + + if ((color.format != Format::Float3) + || (albedo && albedo.format != Format::Float3) + || (normal && normal.format != Format::Float3) + || (output.format != Format::Float3)) + throw Exception(Error::InvalidOperation, "unsupported image format"); + + if ((albedo && (albedo.width != W || albedo.height != H)) + || (normal && (normal.width != W || normal.height != H)) + || (output.width != W || output.height != H)) + throw Exception(Error::InvalidOperation, "image size mismatch"); + + // Compute the tile size + computeTileSize(); + + // If the image size is zero, there is nothing else to do + if (H <= 0 || W <= 0) + return nullptr; + + // Parse the weights + const auto weightMap = parseTensors(weightPtr); + + // Create the network + std::shared_ptr> net = std::make_shared>(device, weightMap); + + // Compute the tensor sizes + const auto inputDims = memory::dims({1, inputC, tileH, tileW}); + const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0 + + const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0 + const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1 + const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1 + const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0 + const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2 + const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0 + const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3 + const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0 + const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4 + const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0 + const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1 + const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4 + const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims); + const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0 + const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1 + const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3 + const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims); + const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0 + const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1 + const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2 + const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims); + const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0 + const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1 + const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1 + const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims); + const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0 + const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1 + const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0 + const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims); + const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0 + const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1 + const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0 + + const auto outputDims = memory::dims({1, 3, tileH, tileW}); + + // Allocate two temporary ping-pong buffers to decrease memory usage + const auto temp0Dims = getMaxTensorDims({ + conv1Dims, + conv2Dims, + conv3Dims, + conv4Dims, + conv5Dims, + conv6Dims, + conv7Dims, + conv8Dims, + conv9Dims, + conv10Dims, + conv11Dims + }); + + const auto temp1Dims = getMaxTensorDims({ + conv1bDims, + pool5Dims, + conv6bDims, + conv7bDims, + conv8bDims, + conv9bDims, + conv10bDims, + }); + + auto temp0 = net->allocTensor(temp0Dims); + auto temp1 = net->allocTensor(temp1Dims); + + // Allocate enough memory to hold the concat outputs. Then use the first + // half to hold the previous conv output and the second half to hold the + // pool/orig image output. This works because everything is C dimension + // outermost, padded to K floats, and all the concats are on the C dimension. + auto concat0Dst = net->allocTensor(concat0Dims); + auto concat1Dst = net->allocTensor(concat1Dims); + auto concat2Dst = net->allocTensor(concat2Dims); + auto concat3Dst = net->allocTensor(concat3Dims); + auto concat4Dst = net->allocTensor(concat4Dims); + + // Transfer function + std::shared_ptr transferFunc = makeTransferFunc(); + + // Autoexposure + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + { + if (isnan(hdrScale)) + net->addAutoexposure(color, tf); + else + tf->setExposure(hdrScale); + } + + // Input reorder + auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims); + inputReorder = net->addInputReorder(color, albedo, normal, + transferFunc, + alignment, inputReorderDst); + + // conv1 + auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0); + + // conv1b + auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1); + + // pool1 + // Adjust pointer for pool1 to eliminate concat1 + auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims); + auto pool1 = net->addPool(conv1b->getDst(), pool1Dst); + + // conv2 + auto conv2 = net->addConv("conv2", pool1->getDst(), temp0); + + // pool2 + // Adjust pointer for pool2 to eliminate concat2 + auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims); + auto pool2 = net->addPool(conv2->getDst(), pool2Dst); + + // conv3 + auto conv3 = net->addConv("conv3", pool2->getDst(), temp0); + + // pool3 + // Adjust pointer for pool3 to eliminate concat3 + auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims); + auto pool3 = net->addPool(conv3->getDst(), pool3Dst); + + // conv4 + auto conv4 = net->addConv("conv4", pool3->getDst(), temp0); + + // pool4 + // Adjust pointer for pool4 to eliminate concat4 + auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims); + auto pool4 = net->addPool(conv4->getDst(), pool4Dst); + + // conv5 + auto conv5 = net->addConv("conv5", pool4->getDst(), temp0); + + // pool5 + auto pool5 = net->addPool(conv5->getDst(), temp1); + + // upsample4 + auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst); + auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst); + + // conv6 + auto conv6 = net->addConv("conv6", concat4Dst, temp0); + + // conv6b + auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1); + + // upsample3 + auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst); + auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst); + + // conv7 + auto conv7 = net->addConv("conv7", concat3Dst, temp0); + + // conv7b + auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1); + + // upsample2 + auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst); + auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst); + + // conv8 + auto conv8 = net->addConv("conv8", concat2Dst, temp0); + + // conv8b + auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1); + + // upsample1 + auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst); + auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst); + + // conv9 + auto conv9 = net->addConv("conv9", concat1Dst, temp0); + + // conv9b + auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1); + + // upsample0 + auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst); + auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst); + + // conv10 + auto conv10 = net->addConv("conv10", concat0Dst, temp0); + + // conv10b + auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1); + + // conv11 + auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */); + + // Output reorder + outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output); + + net->finalize(); + return net; + } + + std::shared_ptr AutoencoderFilter::makeTransferFunc() + { + if (hdr) + return std::make_shared(); + else if (srgb) + return std::make_shared(); + else + return std::make_shared(); + } + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + // -------------------------------------------------------------------------- + // RTFilter + // -------------------------------------------------------------------------- + + namespace weights + { + // LDR + extern unsigned char rt_ldr[]; // color + extern unsigned char rt_ldr_alb[]; // color, albedo + extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal + + // HDR + extern unsigned char rt_hdr[]; // color + extern unsigned char rt_hdr_alb[]; // color, albedo + extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal + } + + RTFilter::RTFilter(const Ref& device) + : AutoencoderFilter(device) + { + weightData.ldr = weights::rt_ldr; + weightData.ldr_alb = weights::rt_ldr_alb; + weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm; + weightData.hdr = weights::rt_hdr; + weightData.hdr_alb = weights::rt_hdr_alb; + weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm; + } +#endif + + // -------------------------------------------------------------------------- + // RTLightmapFilter + // -------------------------------------------------------------------------- + + namespace weights + { + // HDR + extern unsigned char rtlightmap_hdr[]; // color + } + + RTLightmapFilter::RTLightmapFilter(const Ref& device) + : AutoencoderFilter(device) + { + weightData.hdr = weights::rtlightmap_hdr; + + hdr = true; + } + + std::shared_ptr RTLightmapFilter::makeTransferFunc() + { + return std::make_shared(); + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.h b/thirdparty/oidn/core/autoencoder.h new file mode 100644 index 0000000000..97432f2bbd --- /dev/null +++ b/thirdparty/oidn/core/autoencoder.h @@ -0,0 +1,116 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "filter.h" +#include "network.h" +#include "transfer_function.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // AutoencoderFilter - Direct-predicting autoencoder + // -------------------------------------------------------------------------- + + class AutoencoderFilter : public Filter + { + protected: + static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary) + static constexpr int receptiveField = 222; // receptive field in pixels + static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels + + static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage + static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8 + static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16 + + Image color; + Image albedo; + Image normal; + Image output; + bool hdr = false; + float hdrScale = std::numeric_limits::quiet_NaN(); + bool srgb = false; + int maxMemoryMB = 6000; // approximate maximum memory usage in MBs + + int H = 0; // image height + int W = 0; // image width + int tileH = 0; // tile height + int tileW = 0; // tile width + int tileCountH = 1; // number of tiles in H dimension + int tileCountW = 1; // number of tiles in W dimension + + std::shared_ptr net; + std::shared_ptr inputReorder; + std::shared_ptr outputReorder; + + struct + { + void* ldr = nullptr; + void* ldr_alb = nullptr; + void* ldr_alb_nrm = nullptr; + void* hdr = nullptr; + void* hdr_alb = nullptr; + void* hdr_alb_nrm = nullptr; + } weightData; + + explicit AutoencoderFilter(const Ref& device); + virtual std::shared_ptr makeTransferFunc(); + + public: + void setImage(const std::string& name, const Image& data) override; + void set1i(const std::string& name, int value) override; + int get1i(const std::string& name) override; + void set1f(const std::string& name, float value) override; + float get1f(const std::string& name) override; + + void commit() override; + void execute() override; + + private: + void computeTileSize(); + + template + std::shared_ptr buildNet(); + + bool isCommitted() const { return bool(net); } + }; + + // -------------------------------------------------------------------------- + // RTFilter - Generic ray tracing denoiser + // -------------------------------------------------------------------------- + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + class RTFilter : public AutoencoderFilter + { + public: + explicit RTFilter(const Ref& device); + }; +#endif + + // -------------------------------------------------------------------------- + // RTLightmapFilter - Ray traced lightmap denoiser + // -------------------------------------------------------------------------- + + class RTLightmapFilter : public AutoencoderFilter + { + public: + explicit RTLightmapFilter(const Ref& device); + std::shared_ptr makeTransferFunc() override; + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/buffer.h b/thirdparty/oidn/core/buffer.h new file mode 100644 index 0000000000..b95109152e --- /dev/null +++ b/thirdparty/oidn/core/buffer.h @@ -0,0 +1,75 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "device.h" + +namespace oidn { + + class Device; + + // Buffer which may or may not own its data + class Buffer : public RefCount + { + private: + char* ptr; + size_t byteSize; + bool shared; + Ref device; + + public: + __forceinline Buffer(const Ref& device, size_t size) + : ptr((char*)alignedMalloc(size, 64)), + byteSize(size), + shared(false), + device(device) {} + + __forceinline Buffer(const Ref& device, void* data, size_t size) + : ptr((char*)data), + byteSize(size), + shared(true), + device(device) + { + if (data == nullptr) + throw Exception(Error::InvalidArgument, "buffer pointer null"); + } + + __forceinline ~Buffer() + { + if (!shared) + alignedFree(ptr); + } + + __forceinline char* data() { return ptr; } + __forceinline const char* data() const { return ptr; } + __forceinline size_t size() const { return byteSize; } + + void* map(size_t offset, size_t size) + { + if (offset + size > byteSize) + throw Exception(Error::InvalidArgument, "buffer region out of range"); + + return ptr + offset; + } + + void unmap(void* mappedPtr) {} + + Device* getDevice() { return device.get(); } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/common.h b/thirdparty/oidn/core/common.h new file mode 100644 index 0000000000..6c87f377bc --- /dev/null +++ b/thirdparty/oidn/core/common.h @@ -0,0 +1,133 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +#include "mkl-dnn/include/mkldnn.hpp" +#include "mkl-dnn/include/mkldnn_debug.h" +#include "mkl-dnn/src/common/mkldnn_thread.hpp" +#include "mkl-dnn/src/common/type_helpers.hpp" +#include "mkl-dnn/src/cpu/jit_generator.hpp" + +#include "common/ref.h" +#include "common/exception.h" +#include "common/thread.h" +#include "math.h" + +namespace oidn { + + using namespace mkldnn; + using namespace mkldnn::impl::cpu; + using mkldnn::impl::parallel_nd; + using mkldnn::impl::memory_desc_matches_tag; + + + inline size_t getFormatBytes(Format format) + { + switch (format) + { + case Format::Undefined: return 1; + case Format::Float: return sizeof(float); + case Format::Float2: return sizeof(float)*2; + case Format::Float3: return sizeof(float)*3; + case Format::Float4: return sizeof(float)*4; + } + assert(0); + return 0; + } + + + inline memory::dims getTensorDims(const std::shared_ptr& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]); + } + + inline memory::data_type getTensorType(const std::shared_ptr& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::data_type(desc.data_type); + } + + // Returns the number of values in a tensor + inline size_t getTensorSize(const memory::dims& dims) + { + size_t res = 1; + for (int i = 0; i < (int)dims.size(); ++i) + res *= dims[i]; + return res; + } + + inline memory::dims getMaxTensorDims(const std::vector& dims) + { + memory::dims result; + size_t maxSize = 0; + + for (const auto& d : dims) + { + const size_t size = getTensorSize(d); + if (size > maxSize) + { + result = d; + maxSize = size; + } + } + + return result; + } + + inline size_t getTensorSize(const std::shared_ptr& mem) + { + return getTensorSize(getTensorDims(mem)); + } + + + template + inline int getPadded(int dim) + { + return (dim + (K-1)) & ~(K-1); + } + + template + inline memory::dims getPadded_nchw(const memory::dims& dims) + { + assert(dims.size() == 4); + memory::dims padDims = dims; + padDims[1] = getPadded(dims[1]); // pad C + return padDims; + } + + + template + struct BlockedFormat; + + template<> + struct BlockedFormat<8> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o; + }; + + template<> + struct BlockedFormat<16> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o; + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/device.cpp b/thirdparty/oidn/core/device.cpp new file mode 100644 index 0000000000..0812624bb5 --- /dev/null +++ b/thirdparty/oidn/core/device.cpp @@ -0,0 +1,205 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "device.h" +#include "autoencoder.h" + +namespace oidn { + + thread_local Device::ErrorState Device::globalError; + + Device::Device() + { + if (!mayiuse(sse41)) + throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum"); + } + + Device::~Device() + { + } + + void Device::setError(Device* device, Error code, const std::string& message) + { + // Update the stored error only if the previous error was queried + if (device) + { + ErrorState& curError = device->error.get(); + + if (curError.code == Error::None) + { + curError.code = code; + curError.message = message; + } + + // Print the error message in verbose mode + if (device->isVerbose()) + std::cerr << "Error: " << message << std::endl; + + // Call the error callback function + ErrorFunction errorFunc; + void* errorUserPtr; + + { + std::lock_guard lock(device->mutex); + errorFunc = device->errorFunc; + errorUserPtr = device->errorUserPtr; + } + + if (errorFunc) + errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str()); + } + else + { + if (globalError.code == Error::None) + { + globalError.code = code; + globalError.message = message; + } + } + } + + Error Device::getError(Device* device, const char** outMessage) + { + // Return and clear the stored error code, but keep the error message so pointers to it will + // remain valid until the next getError call + if (device) + { + ErrorState& curError = device->error.get(); + const Error code = curError.code; + if (outMessage) + *outMessage = (code == Error::None) ? nullptr : curError.message.c_str(); + curError.code = Error::None; + return code; + } + else + { + const Error code = globalError.code; + if (outMessage) + *outMessage = (code == Error::None) ? nullptr : globalError.message.c_str(); + globalError.code = Error::None; + return code; + } + } + + void Device::setErrorFunction(ErrorFunction func, void* userPtr) + { + errorFunc = func; + errorUserPtr = userPtr; + } + + int Device::get1i(const std::string& name) + { + if (name == "numThreads") + return numThreads; + else if (name == "setAffinity") + return setAffinity; + else if (name == "verbose") + return verbose; + else if (name == "version") + return OIDN_VERSION; + else if (name == "versionMajor") + return OIDN_VERSION_MAJOR; + else if (name == "versionMinor") + return OIDN_VERSION_MINOR; + else if (name == "versionPatch") + return OIDN_VERSION_PATCH; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void Device::set1i(const std::string& name, int value) + { + if (name == "numThreads") + numThreads = value; + else if (name == "setAffinity") + setAffinity = value; + else if (name == "verbose") + { + verbose = value; + error.verbose = value; + } + + dirty = true; + } + + void Device::commit() + { + if (isCommitted()) + throw Exception(Error::InvalidOperation, "device can be committed only once"); + + // Create the task arena + const int maxNumThreads = 1; //affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency(); + numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads; + + dirty = false; + + if (isVerbose()) + print(); + } + + void Device::checkCommitted() + { + if (dirty) + throw Exception(Error::InvalidOperation, "changes to the device are not committed"); + } + + Ref Device::newBuffer(size_t byteSize) + { + checkCommitted(); + return makeRef(Ref(this), byteSize); + } + + Ref Device::newBuffer(void* ptr, size_t byteSize) + { + checkCommitted(); + return makeRef(Ref(this), ptr, byteSize); + } + + Ref Device::newFilter(const std::string& type) + { + checkCommitted(); + + if (isVerbose()) + std::cout << "Filter: " << type << std::endl; + + Ref filter; + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + if (type == "RT") + filter = makeRef(Ref(this)); +#endif + if (type == "RTLightmap") + filter = makeRef(Ref(this)); + else + throw Exception(Error::InvalidArgument, "unknown filter type"); + + return filter; + } + + void Device::print() + { + std::cout << std::endl; + + std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl; + std::cout << " Compiler: " << getCompilerName() << std::endl; + std::cout << " Build : " << getBuildName() << std::endl; + std::cout << " Platform: " << getPlatformName() << std::endl; + + std::cout << std::endl; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/device.h b/thirdparty/oidn/core/device.h new file mode 100644 index 0000000000..93a83eb731 --- /dev/null +++ b/thirdparty/oidn/core/device.h @@ -0,0 +1,78 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" + +namespace oidn { + + class Buffer; + class Filter; + + class Device : public RefCount, public Verbose + { + private: + // Thread-safety + std::mutex mutex; + + // Error handling + struct ErrorState + { + Error code = Error::None; + std::string message; + }; + + static thread_local ErrorState globalError; + ThreadLocal error; + ErrorFunction errorFunc = nullptr; + void* errorUserPtr = nullptr; + + // Parameters + int numThreads = 0; // autodetect by default + bool setAffinity = true; + + bool dirty = true; + + public: + Device(); + ~Device(); + + static void setError(Device* device, Error code, const std::string& message); + static Error getError(Device* device, const char** outMessage); + + void setErrorFunction(ErrorFunction func, void* userPtr); + + int get1i(const std::string& name); + void set1i(const std::string& name, int value); + + void commit(); + + Ref newBuffer(size_t byteSize); + Ref newBuffer(void* ptr, size_t byteSize); + Ref newFilter(const std::string& type); + + __forceinline Device* getDevice() { return this; } + __forceinline std::mutex& getMutex() { return mutex; } + + private: + bool isCommitted() const { return false; } + void checkCommitted(); + + void print(); + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/filter.cpp b/thirdparty/oidn/core/filter.cpp new file mode 100644 index 0000000000..ec1f10af87 --- /dev/null +++ b/thirdparty/oidn/core/filter.cpp @@ -0,0 +1,27 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "filter.h" + +namespace oidn { + + void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr) + { + progressFunc = func; + progressUserPtr = userPtr; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/filter.h b/thirdparty/oidn/core/filter.h new file mode 100644 index 0000000000..935fa202f4 --- /dev/null +++ b/thirdparty/oidn/core/filter.h @@ -0,0 +1,52 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "device.h" +#include "image.h" + +namespace oidn { + + class Filter : public RefCount + { + protected: + Ref device; + + ProgressMonitorFunction progressFunc = nullptr; + void* progressUserPtr = nullptr; + + bool dirty = true; + + public: + explicit Filter(const Ref& device) : device(device) {} + + virtual void setImage(const std::string& name, const Image& data) = 0; + virtual void set1i(const std::string& name, int value) = 0; + virtual int get1i(const std::string& name) = 0; + virtual void set1f(const std::string& name, float value) = 0; + virtual float get1f(const std::string& name) = 0; + + void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr); + + virtual void commit() = 0; + virtual void execute() = 0; + + Device* getDevice() { return device.get(); } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/image.h b/thirdparty/oidn/core/image.h new file mode 100644 index 0000000000..748f49c4e5 --- /dev/null +++ b/thirdparty/oidn/core/image.h @@ -0,0 +1,111 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "buffer.h" + +namespace oidn { + + struct Image + { + static constexpr int maxSize = 65536; + + char* ptr; // pointer to the first pixel + int width; // width in number of pixels + int height; // height in number of pixels + size_t bytePixelStride; // pixel stride in number of *bytes* + size_t rowStride; // row stride in number of *pixel strides* + Format format; // pixel format + Ref buffer; // buffer containing the image data + + Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {} + + Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) + { + if (ptr == nullptr) + throw Exception(Error::InvalidArgument, "buffer pointer null"); + + init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); + } + + Image(const Ref& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) + { + init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); + + if (byteOffset + height * rowStride * bytePixelStride > buffer->size()) + throw Exception(Error::InvalidArgument, "buffer region out of range"); + } + + void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride) + { + assert(width >= 0); + assert(height >= 0); + if (width > maxSize || height > maxSize) + throw Exception(Error::InvalidArgument, "image size too large"); + + this->ptr = ptr; + this->width = width; + this->height = height; + + const size_t pixelSize = getFormatBytes(format); + if (inBytePixelStride != 0) + { + if (inBytePixelStride < pixelSize) + throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size"); + + this->bytePixelStride = inBytePixelStride; + } + else + { + this->bytePixelStride = pixelSize; + } + + if (inByteRowStride != 0) + { + if (inByteRowStride < width * this->bytePixelStride) + throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride"); + if (inByteRowStride % this->bytePixelStride != 0) + throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride"); + + this->rowStride = inByteRowStride / this->bytePixelStride; + } + else + { + this->rowStride = width; + } + + this->format = format; + } + + __forceinline char* get(int y, int x) + { + return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); + } + + __forceinline const char* get(int y, int x) const + { + return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); + } + + operator bool() const + { + return ptr != nullptr; + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/input_reorder.h b/thirdparty/oidn/core/input_reorder.h new file mode 100644 index 0000000000..966856afe9 --- /dev/null +++ b/thirdparty/oidn/core/input_reorder.h @@ -0,0 +1,232 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" +#include "image.h" + +namespace oidn { + + // Input reorder node + template + class InputReorderNode : public Node + { + private: + // Source + Image color; + Image albedo; + Image normal; + + // Destination + std::shared_ptr dst; + float* dstPtr; + int C2; + int H2; + int W2; + + // Tile + int h1Begin; + int w1Begin; + int h2Begin; + int w2Begin; + int H; + int W; + + std::shared_ptr transferFunc; + + public: + InputReorderNode(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& dst, + const std::shared_ptr& transferFunc) + : color(color), albedo(albedo), normal(normal), + dst(dst), + h1Begin(0), w1Begin(0), + H(color.height), W(color.width), + transferFunc(transferFunc) + { + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(dstDesc.ndims == 4); + assert(dstDesc.data_type == memory::data_type::f32); + assert(dstDesc.dims[0] == 1); + //assert(dstDesc.dims[1] >= getPadded(C1)); + + dstPtr = (float*)dst->get_data_handle(); + C2 = dstDesc.dims[1]; + H2 = dstDesc.dims[2]; + W2 = dstDesc.dims[3]; + } + + void setTile(int h1, int w1, int h2, int w2, int H, int W) override + { + h1Begin = h1; + w1Begin = w1; + h2Begin = h2; + w2Begin = w2; + this->H = H; + this->W = W; + } + + void execute(stream& sm) override + { + assert(H + h1Begin <= color.height); + assert(W + w1Begin <= color.width); + assert(H + h2Begin <= H2); + assert(W + w2Begin <= W2); + + parallel_nd(H2, [&](int h2) + { + const int h = h2 - h2Begin; + + if (h >= 0 && h < H) + { + const int h1 = h + h1Begin; + + // Zero pad + for (int w2 = 0; w2 < w2Begin; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + + // Reorder + for (int w = 0; w < W; ++w) + { + const int w1 = w + w1Begin; + const int w2 = w + w2Begin; + + int c = 0; + storeColor(h2, w2, c, (float*)color.get(h1, w1)); + if (albedo) + storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1)); + if (normal) + storeNormal(h2, w2, c, (float*)normal.get(h1, w1)); + while (c < C2) + store(h2, w2, c, 0.f); + } + + // Zero pad + for (int w2 = W + w2Begin; w2 < W2; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + } + else + { + // Zero pad + for (int w2 = 0; w2 < W2; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + } + }); + } + + std::shared_ptr getDst() const override { return dst; } + + private: + // Stores a single value + __forceinline void store(int h, int w, int& c, float value) + { + // Destination is in nChwKc format + float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K); + *dst_c = value; + c++; + } + + // Stores a color + __forceinline void storeColor(int h, int w, int& c, const float* values) + { + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = values[i]; + + // Sanitize the value + x = maxSafe(x, 0.f); + + // Apply the transfer function + x = transferFunc->forward(x); + + // Store the value + store(h, w, c, x); + } + } + + // Stores an albedo + __forceinline void storeAlbedo(int h, int w, int& c, const float* values) + { + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = values[i]; + + // Sanitize the value + x = clampSafe(x, 0.f, 1.f); + + // Store the value + store(h, w, c, x); + } + } + + // Stores a normal + __forceinline void storeNormal(int h, int w, int& c, const float* values) + { + // Load the normal + float x = values[0]; + float y = values[1]; + float z = values[2]; + + // Compute the length of the normal + const float lengthSqr = sqr(x) + sqr(y) + sqr(z); + + // Normalize the normal and transform it to [0..1] + if (isfinite(lengthSqr)) + { + const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f; + + const float scale = invLength * 0.5f; + const float offset = 0.5f; + + x = x * scale + offset; + y = y * scale + offset; + z = z * scale + offset; + } + else + { + x = 0.f; + y = 0.f; + z = 0.f; + } + + // Store the normal + store(h, w, c, x); + store(h, w, c, y); + store(h, w, c, z); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/math.h b/thirdparty/oidn/core/math.h new file mode 100644 index 0000000000..a844ef0d1d --- /dev/null +++ b/thirdparty/oidn/core/math.h @@ -0,0 +1,78 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +namespace oidn { + + constexpr float minVectorLength = 1e-10f; + constexpr float minVectorLengthSqr = minVectorLength * minVectorLength; + + using std::log; + using std::log2; + using std::exp; + using std::exp2; + using std::pow; + using std::isfinite; + using std::isnan; + + __forceinline float sqr(float x) + { + return x * x; + } + + __forceinline float rcp(float x) + { + __m128 r = _mm_rcp_ss(_mm_set_ss(x)); + return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x)))); + } + + __forceinline float rsqrt(float x) + { + __m128 r = _mm_rsqrt_ss(_mm_set_ss(x)); + return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r), + _mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r)))); + } + + __forceinline float maxSafe(float value, float minValue) + { + return isfinite(value) ? max(value, minValue) : minValue; + } + + __forceinline float clampSafe(float value, float minValue, float maxValue) + { + return isfinite(value) ? clamp(value, minValue, maxValue) : minValue; + } + + // Returns ceil(a / b) for non-negative integers + template + __forceinline constexpr Int ceilDiv(Int a, Int b) + { + //assert(a >= 0); + //assert(b > 0); + return (a + b - 1) / b; + } + + // Returns a rounded up to multiple of b + template + __forceinline constexpr Int roundUp(Int a, Int b) + { + return ceilDiv(a, b) * b; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/network.cpp b/thirdparty/oidn/core/network.cpp new file mode 100644 index 0000000000..4da32073cd --- /dev/null +++ b/thirdparty/oidn/core/network.cpp @@ -0,0 +1,434 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "network.h" +#include "upsample.h" +#include "weights_reorder.h" +#include + +namespace oidn { + + template + Network::Network(const Ref& device, const std::map& weightMap) + : device(device), + eng(engine::cpu, 0), + sm(eng), + weightMap(weightMap) + { + } + + template + void Network::execute(const Progress& progress, int taskIndex) + { + if (progress.func) + { + const double value = double(taskIndex) / double(progress.taskCount); + if (!progress.func(progress.userPtr, value)) + throw Exception(Error::Cancelled, "execution was cancelled"); + } + + for (size_t i = 0; i < nodes.size(); ++i) + { + nodes[i]->execute(sm); + + if (progress.func) + { + const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount); + if (!progress.func(progress.userPtr, value)) + throw Exception(Error::Cancelled, "execution was cancelled"); + } + } + } + + template + std::shared_ptr Network::allocTensor(const memory::dims& dims, + memory::format_tag format, + void* data) + { + if (format == memory::format_tag::any) + { + if (dims.size() == 4) + format = BlockedFormat::nChwKc; + else if (dims.size() == 1) + format = memory::format_tag::x; + else + assert(0); + } + memory::desc desc(dims, memory::data_type::f32, format); + if (data == nullptr) + { + const size_t bytes = getTensorSize(dims) * sizeof(float); + if (format == BlockedFormat::nChwKc) + activationAllocBytes += bytes; + totalAllocBytes += bytes; + + return std::make_shared(desc, eng); + } + else + { + return std::make_shared(desc, eng, data); + } + } + + template + std::shared_ptr Network::castTensor(const memory::dims& dims, + const std::shared_ptr& src, + size_t srcOffset, + memory::format_tag format) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + MAYBE_UNUSED(srcDesc); + assert(srcDesc.data_type == memory::data_type::f32); + assert(getTensorSize(src) >= srcOffset + getTensorSize(dims)); + + if (format == memory::format_tag::any) + { + if (dims.size() == 4) + format = BlockedFormat::nChwKc; + else if (dims.size() == 1) + format = memory::format_tag::x; + else + assert(0); + } + memory::desc desc(dims, memory::data_type::f32, format); + float* srcPtr = (float*)src->get_data_handle() + srcOffset; + return std::make_shared(desc, eng, srcPtr); + } + + template + std::shared_ptr Network::castTensor(const memory::dims& dims, + const std::shared_ptr& src, + const memory::dims& srcOffset) + { + return castTensor(dims, src, getTensorSize(srcOffset)); + } + + template + void Network::zeroTensor(const std::shared_ptr& dst) + { + assert(getTensorType(dst) == memory::data_type::f32); + memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float)); + } + + template + memory::dims Network::getInputReorderDims(const memory::dims& srcDims, int alignment) + { + memory::dims dstDims = srcDims; + dstDims[1] = getPadded(srcDims[1]); // round up C + dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H + dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W + return dstDims; + } + + template + std::shared_ptr Network::addInputReorder(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& transferFunc, + int alignment, + const std::shared_ptr& userDst) + { + assert(color); + int inputC = 3; + if (albedo) inputC += 3; + if (normal) inputC += 3; + + memory::dims srcDims = {1, inputC, color.height, color.width}; + memory::dims dstDims = getInputReorderDims(srcDims, alignment); + + // Allocate padded memory + auto dst = userDst; + if (!dst) + dst = allocTensor(dstDims); + + // Push node + std::shared_ptr node; + + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(color, albedo, normal, dst, tf); + else + assert(0); + + nodes.push_back(node); + return node; + } + + template + std::shared_ptr Network::addOutputReorder(const std::shared_ptr& src, + const std::shared_ptr& transferFunc, + const Image& output) + { + memory::dims srcDims = getTensorDims(src); + assert(srcDims[1] == K); + + // Push node + std::shared_ptr node; + + if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast(transferFunc)) + node = std::make_shared>(src, output, tf); + else + assert(0); + + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getConvDims(const std::string& name, const memory::dims& srcDims) + { + auto b = weightMap[name + "/b"]; + memory::dims dstDims = srcDims; + dstDims[1] = getPadded(b.dims[0]); // dstDims[C] = getPadded(OC) + return dstDims; + } + + template + std::shared_ptr Network::addConv(const std::string& name, + const std::shared_ptr& src, + const std::shared_ptr& userDst, + bool relu) + { + const memory::dims strides = {1, 1}; + const memory::dims padding = {1, 1}; + + memory::dims srcDims = getTensorDims(src); + + // Get the weights + const auto& W = weightMap[name + "/W"]; + if (W.ndims() != 4 || W.format != "oihw") + throw Exception(Error::InvalidOperation, "invalid convolution weights"); + memory::dims weightsDims = W.dims; + auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data); + + // Pad the weights + memory::dims weightsPadDims = weightsDims; + weightsPadDims[1] = getPadded(weightsDims[1]); // IC + weightsPadDims[0] = getPadded(weightsDims[0]); // OC + assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC] + auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw); + WeightsReorderNode(userWeights, weightsPad).execute(sm); + + // Get the biases + const auto& b = weightMap[name + "/b"]; + if (b.ndims() != 1) + throw Exception(Error::InvalidOperation, "invalid convolution biases"); + memory::dims biasDims = b.dims; + + // Copy/pad the biases + memory::dims biasPadDims = {getPadded(biasDims[0])}; + auto bias = allocTensor(biasPadDims); + if (biasDims[0] != biasPadDims[0]) + memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float)); + memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float)); + + // Allocate memory for destination + memory::dims dstDims = srcDims; + dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC] + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + // Create a convolution + // Let the convolution primitive choose the weights format + auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any); + + auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct; + auto convDesc = convolution_forward::desc( + prop_kind::forward_inference, convAlgo, + src->get_desc(), + weightsDesc, + bias->get_desc(), + dst->get_desc(), + strides, padding, padding, padding_kind::zero); + + // Incorporate relu + mkldnn::primitive_attr convAttr; + if (relu) + { + mkldnn::post_ops ops; + ops.append_eltwise( + 1.f, // scale factor, not used + algorithm::eltwise_relu, + 0.f, // max with + 0.f // unused + ); + convAttr.set_post_ops(ops); + } + convAttr.set_scratchpad_mode(scratchpad_mode_user); + + auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng); + + // Reorder the weights to the final format, if necessary + auto weights = weightsPad; + if (convPrimDesc.weights_desc() != weightsPad->get_desc()) + { + weights = std::make_shared(convPrimDesc.weights_desc(), eng); + ReorderNode(weightsPad, weights).execute(sm); + } + + // Create convolution node and add it to the net + auto node = std::make_shared(convPrimDesc, src, weights, bias, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getPoolDims(const memory::dims& srcDims) + { + memory::dims dstDims = srcDims; + dstDims[2] /= 2; // H/2 + dstDims[3] /= 2; // W/2 + return dstDims; + } + + template + std::shared_ptr Network::addPool(const std::shared_ptr& src, + const std::shared_ptr& userDst) + { + const memory::dims kernel = {2, 2}; + const memory::dims strides = {2, 2}; + const memory::dims padding = {0, 0}; + + memory::dims srcDims = getTensorDims(src); + memory::dims dstDims = getPoolDims(srcDims); + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + auto poolDesc = pooling_forward::desc( + prop_kind::forward_inference, pooling_max, + src->get_desc(), + dst->get_desc(), + strides, kernel, padding, padding, padding_kind::zero); + + mkldnn::primitive_attr poolAttr; + poolAttr.set_scratchpad_mode(scratchpad_mode_user); + + auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng); + + auto node = std::make_shared(poolPrimDesc, src, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getUpsampleDims(const memory::dims& srcDims) + { + memory::dims dstDims = srcDims; + dstDims[2] *= 2; // H*2 + dstDims[3] *= 2; // W*2 + return dstDims; + } + + template + std::shared_ptr Network::addUpsample(const std::shared_ptr& src, + const std::shared_ptr& userDst) + { + memory::dims srcDims = getTensorDims(src); + memory::dims dstDims = getUpsampleDims(srcDims); + + std::shared_ptr dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + // Create upsampling node and add it to net + auto node = std::make_shared>(src, dst); + nodes.push_back(node); + return node; + } + + template + memory::dims Network::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims) + { + assert(src1Dims[0] == src2Dims[0]); // N + assert(src1Dims[2] == src2Dims[2]); // H + assert(src1Dims[3] == src2Dims[3]); // W + + memory::dims dstDims = src1Dims; + dstDims[1] += src2Dims[1]; // C + return dstDims; + } + + template + std::shared_ptr Network::addAutoexposure(const Image& color, + const std::shared_ptr& transferFunc) + { + auto node = std::make_shared(color, transferFunc); + nodes.push_back(node); + return node; + } + + template + void Network::finalize() + { + // Compute the size of the scratchpad + size_t scratchpadSize = 0; + for (const auto& node : nodes) + scratchpadSize = max(scratchpadSize, node->getScratchpadSize()); + + // Allocate the scratchpad + memory::dims scratchpadDims = { memory::dim(scratchpadSize) }; + memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x); + auto scratchpad = std::make_shared(scratchpadDesc, eng); + activationAllocBytes += scratchpadSize; + totalAllocBytes += scratchpadSize; + + // Set the scratchpad for the nodes + for (auto& node : nodes) + node->setScratchpad(scratchpad); + + // Free the weights + weightMap.clear(); + + // Print statistics + if (device->isVerbose(2)) + { + std::cout << "Activation bytes: " << activationAllocBytes << std::endl; + std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl; + std::cout << "Total bytes : " << totalAllocBytes << std::endl; + } + } + + template class Network<8>; + template class Network<16>; + +} // namespace oidn diff --git a/thirdparty/oidn/core/network.h b/thirdparty/oidn/core/network.h new file mode 100644 index 0000000000..7a696fd355 --- /dev/null +++ b/thirdparty/oidn/core/network.h @@ -0,0 +1,112 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "common/tensor.h" +#include "image.h" +#include "node.h" +#include "input_reorder.h" +#include "output_reorder.h" +#include "transfer_function.h" + +#pragma once + +namespace oidn { + + // Progress state + struct Progress + { + ProgressMonitorFunction func; + void* userPtr; + int taskCount; + }; + + class Executable + { + public: + virtual ~Executable() {} + virtual void execute(const Progress& progress, int taskIndex) = 0; + }; + + template + class Network : public Executable + { + public: + Network(const Ref& device, const std::map& weightMap); + + void execute(const Progress& progress, int taskIndex) override; + + std::shared_ptr allocTensor(const memory::dims& dims, + memory::format_tag format = memory::format_tag::any, + void* data = nullptr); + + std::shared_ptr castTensor(const memory::dims& dims, + const std::shared_ptr& src, + size_t srcOffset = 0, + memory::format_tag format = memory::format_tag::any); + + std::shared_ptr castTensor(const memory::dims& dims, + const std::shared_ptr& src, + const memory::dims& srcOffset); + + void zeroTensor(const std::shared_ptr& dst); + + memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment); + + std::shared_ptr addInputReorder(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr& transferFunc, + int alignment, + const std::shared_ptr& userDst = nullptr); + + std::shared_ptr addOutputReorder(const std::shared_ptr& src, + const std::shared_ptr& transferFunc, + const Image& output); + + memory::dims getConvDims(const std::string& name, const memory::dims& srcDims); + std::shared_ptr addConv(const std::string& name, + const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr, + bool relu = true); + + memory::dims getPoolDims(const memory::dims& srcDims); + std::shared_ptr addPool(const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr); + + memory::dims getUpsampleDims(const memory::dims& srcDims); + std::shared_ptr addUpsample(const std::shared_ptr& src, + const std::shared_ptr& userDst = nullptr); + + memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims); + + std::shared_ptr addAutoexposure(const Image& color, + const std::shared_ptr& transferFunc); + + void finalize(); + + private: + Ref device; + engine eng; + stream sm; + std::vector> nodes; + std::map weightMap; + + // Memory allocation statistics + size_t activationAllocBytes = 0; // number of allocated activation bytes + size_t totalAllocBytes = 0; // total number of allocated bytes + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h new file mode 100644 index 0000000000..b9ffe906df --- /dev/null +++ b/thirdparty/oidn/core/node.h @@ -0,0 +1,142 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include + +namespace oidn { + + class Node + { + public: + virtual ~Node() = default; + + virtual void execute(stream& sm) = 0; + + virtual std::shared_ptr getDst() const { return nullptr; } + + virtual size_t getScratchpadSize() const { return 0; } + virtual void setScratchpad(const std::shared_ptr& mem) {} + + virtual void setTile(int h1, int w1, int h2, int w2, int H, int W) + { + assert(0); // not supported + } + }; + + // Node wrapping an MKL-DNN primitive + class MklNode : public Node + { + private: + primitive prim; + std::unordered_map args; + std::shared_ptr scratchpad; + + public: + MklNode(const primitive& prim, const std::unordered_map& args) + : prim(prim), + args(args) + {} + + size_t getScratchpadSize() const override + { + const auto primDesc = prim.get_primitive_desc(); + const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0); + if (scratchpadDesc == nullptr) + return 0; + return mkldnn_memory_desc_get_size(scratchpadDesc); + } + + void setScratchpad(const std::shared_ptr& mem) override + { + scratchpad = mem; + args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad)); + } + + void execute(stream& sm) override + { + prim.execute(sm, args); + } + }; + + // Convolution node + class ConvNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr weights; + std::shared_ptr bias; + std::shared_ptr dst; + + public: + ConvNode(const convolution_forward::primitive_desc& desc, + const std::shared_ptr& src, + const std::shared_ptr& weights, + const std::shared_ptr& bias, + const std::shared_ptr& dst) + : MklNode(convolution_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_WEIGHTS, *weights }, + { MKLDNN_ARG_BIAS, *bias }, + { MKLDNN_ARG_DST, *dst } }), + src(src), weights(weights), bias(bias), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + + // Pooling node + class PoolNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + PoolNode(const pooling_forward::primitive_desc& desc, + const std::shared_ptr& src, + const std::shared_ptr& dst) + : MklNode(pooling_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + + // Reorder node + class ReorderNode : public MklNode + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + ReorderNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : MklNode(reorder(reorder::primitive_desc(*src, *dst)), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/output_reorder.h b/thirdparty/oidn/core/output_reorder.h new file mode 100644 index 0000000000..7918d48e15 --- /dev/null +++ b/thirdparty/oidn/core/output_reorder.h @@ -0,0 +1,126 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" +#include "image.h" + +namespace oidn { + + // Output reorder node + template + class OutputReorderNode : public Node + { + private: + // Source + std::shared_ptr src; + const float* srcPtr; + int H1; + int W1; + + // Destination + Image output; + + // Tile + int h1Begin; + int w1Begin; + int h2Begin; + int w2Begin; + int H; + int W; + + std::shared_ptr transferFunc; + + public: + OutputReorderNode(const std::shared_ptr& src, + const Image& output, + const std::shared_ptr& transferFunc) + : src(src), + output(output), + h1Begin(0), w1Begin(0), + h2Begin(0), w2Begin(0), + H(output.height), W(output.width), + transferFunc(transferFunc) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + MAYBE_UNUSED(srcDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(srcDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(srcDesc.dims[0] == 1); + // We assume output data is <= K OC + assert(srcDesc.dims[1] == K); + + srcPtr = (float*)src->get_data_handle(); + H1 = srcDesc.dims[2]; + W1 = srcDesc.dims[3]; + } + + void setTile(int h1, int w1, int h2, int w2, int H, int W) override + { + h1Begin = h1; + w1Begin = w1; + h2Begin = h2; + w2Begin = w2; + this->H = H; + this->W = W; + } + + void execute(stream& sm) override + { + assert(h1Begin + H <= H1); + assert(w1Begin + W <= W1); + assert(h2Begin + H <= output.height); + assert(w2Begin + W <= output.width); + + const int C1 = K; + + parallel_nd(H, [&](int h) + { + const int h1 = h + h1Begin; + const int h2 = h + h2Begin; + + for (int w = 0; w < W; ++w) + { + const int w1 = w + w1Begin; + const int w2 = w + w2Begin; + float* dstPtr_C = (float*)output.get(h2, w2); + + // Source is in nChwKc format. In this case C is 1 so this is really nhwc + const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1; + + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = srcPtr_C[i]; + + // The CNN output may contain negative values or even NaNs, so it must be sanitized + x = maxSafe(x, 0.f); + + // Apply the inverse transfer function + x = transferFunc->inverse(x); + + // Sanitize and store the final value + dstPtr_C[i] = max(x, 0.f); + } + } + }); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.cpp b/thirdparty/oidn/core/transfer_function.cpp new file mode 100644 index 0000000000..a33e3c84bc --- /dev/null +++ b/thirdparty/oidn/core/transfer_function.cpp @@ -0,0 +1,95 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "transfer_function.h" + +namespace oidn { + + const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f); + const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale); + + float AutoexposureNode::autoexposure(const Image& color) + { + assert(color.format == Format::Float3); + return 1.0f; + + /*constexpr float key = 0.18f; + constexpr float eps = 1e-8f; + constexpr int K = 16; // downsampling amount + + // Downsample the image to minimize sensitivity to noise + const int H = color.height; // original height + const int W = color.width; // original width + const int HK = (H + K/2) / K; // downsampled height + const int WK = (W + K/2) / K; // downsampled width + + // Compute the average log luminance of the downsampled image + using Sum = std::pair; + + Sum sum = + tbb::parallel_reduce( + tbb::blocked_range2d(0, HK, 0, WK), + Sum(0.f, 0), + [&](const tbb::blocked_range2d& r, Sum sum) -> Sum + { + // Iterate over blocks + for (int i = r.rows().begin(); i != r.rows().end(); ++i) + { + for (int j = r.cols().begin(); j != r.cols().end(); ++j) + { + // Compute the average luminance in the current block + const int beginH = int(ptrdiff_t(i) * H / HK); + const int beginW = int(ptrdiff_t(j) * W / WK); + const int endH = int(ptrdiff_t(i+1) * H / HK); + const int endW = int(ptrdiff_t(j+1) * W / WK); + + float L = 0.f; + + for (int h = beginH; h < endH; ++h) + { + for (int w = beginW; w < endW; ++w) + { + const float* rgb = (const float*)color.get(h, w); + + const float r = maxSafe(rgb[0], 0.f); + const float g = maxSafe(rgb[1], 0.f); + const float b = maxSafe(rgb[2], 0.f); + + L += luminance(r, g, b); + } + } + + L /= (endH - beginH) * (endW - beginW); + + // Accumulate the log luminance + if (L > eps) + { + sum.first += log2(L); + sum.second++; + } + } + } + + return sum; + }, + [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, + tbb::static_partitioner() + ); + + return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;*/ + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.h b/thirdparty/oidn/core/transfer_function.h new file mode 100644 index 0000000000..35f2833092 --- /dev/null +++ b/thirdparty/oidn/core/transfer_function.h @@ -0,0 +1,201 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "image.h" +#include "node.h" + +namespace oidn { + + __forceinline float luminance(float r, float g, float b) + { + return 0.212671f * r + 0.715160f * g + 0.072169f * b; + } + + // Color transfer function base class + class TransferFunction + { + public: + virtual ~TransferFunction() = default; + + virtual float forward(float y) const = 0; + virtual float inverse(float x) const = 0; + }; + + // HDR transfer function base class + class HDRTransferFunction : public TransferFunction + { + protected: + static constexpr float yMax = 65504.f; + + float exposure; + float rcpExposure; + + public: + HDRTransferFunction(float exposure = 1.f) + { + setExposure(exposure); + } + + void setExposure(float exposure) + { + this->exposure = exposure; + this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f; + } + }; + + // Linear transfer function (LDR) + class LinearTransferFunction : public TransferFunction + { + public: + __forceinline float forward(float y) const override + { + return min(y, 1.f); + } + + __forceinline float inverse(float x) const override + { + return min(x, 1.f); + } + }; + + // 2.2 gamma transfer function (LDR) + class GammaTransferFunction : public TransferFunction + { + public: + __forceinline float forward(float y) const override + { + return min(pow(y, 1.f/2.2f), 1.f); + } + + __forceinline float inverse(float x) const override + { + return min(pow(x, 2.2f), 1.f); + } + }; + + // Logarithmic transfer function (HDR) + // Compresses [0..65504] to [0..1] + class LogTransferFunction : public HDRTransferFunction + { + private: + static const float xScale; + + public: + LogTransferFunction(float exposure = 1.f) + : HDRTransferFunction(exposure) + { + } + + __forceinline float forward(float y) const override + { + return log(y * exposure + 1.f) * xScale; + } + + __forceinline float inverse(float x) const override + { + return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure; + } + }; + + // PQX transfer function (HDR) + // Compresses [0..65504] to [0..1] + class PQXTransferFunction : public HDRTransferFunction + { + private: + static constexpr float m1 = 2610.f / 4096.f / 4.f; + static constexpr float m2 = 2523.f / 4096.f * 128.f; + static constexpr float c1 = 3424.f / 4096.f; + static constexpr float c2 = 2413.f / 4096.f * 32.f; + static constexpr float c3 = 2392.f / 4096.f * 32.f; + static constexpr float a = 3711.f / 4096.f / 8.f; + + static constexpr float yScale = 100.f / 10000.f; + static const float xScale; + + public: + PQXTransferFunction(float exposure = 1.f) + : HDRTransferFunction(exposure) + { + } + + __forceinline float forward(float y) const override + { + return pqxForward(y * exposure * yScale) * xScale; + } + + __forceinline float inverse(float x) const override + { + return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure; + } + + private: + static __forceinline float pqForward(float y) + { + const float yp = pow(y, m1); + return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2); + } + + static __forceinline float pqxForward(float y) + { + if (y <= 1.f) + return pqForward(y); + else + return a * log(y) + 1.f; + } + + static __forceinline float pqInverse(float x) + { + const float xp = pow(x, 1.f/m2); + return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1); + } + + static __forceinline float pqxInverse(float x) + { + if (x <= 1.f) + return pqInverse(x); + else + return exp((x - 1.f) * (1.f/a)); + } + }; + + // Autoexposure node + class AutoexposureNode : public Node + { + private: + Image color; + std::shared_ptr transferFunc; + + public: + AutoexposureNode(const Image& color, + const std::shared_ptr& transferFunc) + : color(color), + transferFunc(transferFunc) + {} + + void execute(stream& sm) override + { + const float exposure = autoexposure(color); + //printf("exposure = %f\n", exposure); + transferFunc->setExposure(exposure); + } + + private: + static float autoexposure(const Image& color); + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/upsample.h b/thirdparty/oidn/core/upsample.h new file mode 100644 index 0000000000..f6cace44cd --- /dev/null +++ b/thirdparty/oidn/core/upsample.h @@ -0,0 +1,92 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" + +namespace oidn { + + // 2x2 nearest-neighbor upsampling node + template + class UpsampleNode : public Node + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + UpsampleNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : src(src), + dst(dst) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + MAYBE_UNUSED(srcDesc); + MAYBE_UNUSED(dstDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); + assert(srcDesc.ndims == 4); + assert(dstDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(dstDesc.data_type == memory::data_type::f32); + assert(srcDesc.dims[0] == 1); + assert(dstDesc.dims[0] == 1); + // 2x2 upsampling + assert(dstDesc.dims[2] == srcDesc.dims[2] * 2); + assert(dstDesc.dims[3] == srcDesc.dims[3] * 2); + } + + void execute(stream& sm) override + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + + const float* srcPtr = (float*)src->get_data_handle(); + float* dstPtr = (float*)dst->get_data_handle(); + + const int C = srcDesc.dims[1]; + const int H = srcDesc.dims[2]; + const int W = srcDesc.dims[3]; + const int CK = C / K; + + parallel_nd(CK, H, [&](int ck, int h) + { + const size_t offset = ck*H*W*K + h*W*K; + const float* srcPtr_line = srcPtr + offset; + float* dstPtr_line0 = dstPtr + offset * 4; + float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line + + for (int w = 0; w < W; ++w) + { + #pragma unroll + for (int k = 0; k < K; k += 4) + { + const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]); + + _mm_stream_ps(&dstPtr_line0[w*2*K + k], m); + _mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m); + _mm_stream_ps(&dstPtr_line1[w*2*K + k], m); + _mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m); + } + } + }); + } + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/weights_reorder.h b/thirdparty/oidn/core/weights_reorder.h new file mode 100644 index 0000000000..6c5dacb8aa --- /dev/null +++ b/thirdparty/oidn/core/weights_reorder.h @@ -0,0 +1,99 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" + +namespace oidn { + + // Reorders weights from oihw to padded oihw format + template + class WeightsReorderNode : public Node + { + private: + std::shared_ptr src; + std::shared_ptr dst; + + public: + WeightsReorderNode(const std::shared_ptr& src, + const std::shared_ptr& dst) + : src(src), + dst(dst) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + MAYBE_UNUSED(srcDesc); + MAYBE_UNUSED(dstDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); + assert(srcDesc.ndims == 4); + assert(dstDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(dstDesc.data_type == memory::data_type::f32); + assert(getPadded(srcDesc.dims[0]) == dstDesc.dims[0]); // OC + assert(getPadded(srcDesc.dims[1]) == dstDesc.dims[1]); // IC + assert(srcDesc.dims[2] == dstDesc.dims[2]); + assert(srcDesc.dims[3] == dstDesc.dims[3]); + } + + void execute(stream& sm) override + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + + const float* srcPtr = (float*)src->get_data_handle(); + float* dstPtr = (float*)dst->get_data_handle(); + + const int OC1 = srcDesc.dims[0]; + const int OC2 = dstDesc.dims[0]; + const int IC1 = srcDesc.dims[1]; + const int IC2 = dstDesc.dims[1]; + const int H = dstDesc.dims[2]; + const int W = dstDesc.dims[3]; + + for (int oc = 0; oc < OC2; ++oc) + { + for (int ic = 0; ic < IC2; ++ic) + { + for (int h = 0; h < H; ++h) + { + for (int w = 0; w < W; ++w) + { + // Output is in oihw format + float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w; + + if (oc < OC1 && ic < IC1) + { + // Input is in oihw format + const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w; + *dstPtr_c = *srcPtr_c; + } + else + { + // padding + *dstPtr_c = 0; + } + } + } + } + } + } + + std::shared_ptr getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.h b/thirdparty/oidn/include/OpenImageDenoise/oidn.h new file mode 100644 index 0000000000..57ba6baa21 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/oidn.h @@ -0,0 +1,214 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include +#include +#include + +#include "version.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#ifndef OIDN_API +#if defined(_WIN32) && !defined(OIDN_STATIC_LIB) +# define OIDN_API __declspec(dllimport) +#else +# define OIDN_API +#endif +#endif + +// ---------------------------------------------------------------------------- +// Device +// ---------------------------------------------------------------------------- + +// Device types +typedef enum +{ + OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically + + OIDN_DEVICE_TYPE_CPU = 1, // CPU device +} OIDNDeviceType; + +// Error codes +typedef enum +{ + OIDN_ERROR_NONE = 0, // no error occurred + OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred + OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified + OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed + OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation + OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported + OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user +} OIDNError; + +// Error callback function +typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message); + +// Device handle +typedef struct OIDNDeviceImpl* OIDNDevice; + +// Creates a new device. +OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type); + +// Retains the device (increments the reference count). +OIDN_API void oidnRetainDevice(OIDNDevice device); + +// Releases the device (decrements the reference count). +OIDN_API void oidnReleaseDevice(OIDNDevice device); + +// Sets a boolean parameter of the device. +OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value); + +// Sets an integer parameter of the device. +OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value); + +// Gets a boolean parameter of the device. +OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name); + +// Gets an integer parameter of the device (e.g. "version"). +OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name); + +// Sets the error callback function of the device. +OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr); + +// Returns the first unqueried error code stored in the device for the current +// thread, optionally also returning a string message (if not NULL), and clears +// the stored error. Can be called with a NULL device as well to check why a +// device creation failed. +OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage); + +// Commits all previous changes to the device. +// Must be called before first using the device (e.g. creating filters). +OIDN_API void oidnCommitDevice(OIDNDevice device); + +// ---------------------------------------------------------------------------- +// Buffer +// ---------------------------------------------------------------------------- + +// Formats for images and other data stored in buffers +typedef enum +{ + OIDN_FORMAT_UNDEFINED = 0, + + // 32-bit single-precision floating point scalar and vector formats + OIDN_FORMAT_FLOAT = 1, + OIDN_FORMAT_FLOAT2 = 2, + OIDN_FORMAT_FLOAT3 = 3, + OIDN_FORMAT_FLOAT4 = 4, +} OIDNFormat; + +// Access modes for mapping buffers +typedef enum +{ + OIDN_ACCESS_READ = 0, // read-only access + OIDN_ACCESS_WRITE = 1, // write-only access + OIDN_ACCESS_READ_WRITE = 2, // read and write access + OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded +} OIDNAccess; + +// Buffer handle +typedef struct OIDNBufferImpl* OIDNBuffer; + +// Creates a new buffer (data allocated and owned by the device). +OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize); + +// Creates a new shared buffer (data allocated and owned by the user). +OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize); + +// Maps a region of the buffer to host memory. +// If byteSize is 0, the maximum available amount of memory will be mapped. +OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize); + +// Unmaps a region of the buffer. +// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer. +OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr); + +// Retains the buffer (increments the reference count). +OIDN_API void oidnRetainBuffer(OIDNBuffer buffer); + +// Releases the buffer (decrements the reference count). +OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer); + +// ---------------------------------------------------------------------------- +// Filter +// ---------------------------------------------------------------------------- + +// Progress monitor callback function +typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n); + +// Filter handle +typedef struct OIDNFilterImpl* OIDNFilter; + +// Creates a new filter of the specified type (e.g. "RT"). +OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type); + +// Retains the filter (increments the reference count). +OIDN_API void oidnRetainFilter(OIDNFilter filter); + +// Releases the filter (decrements the reference count). +OIDN_API void oidnReleaseFilter(OIDNFilter filter); + +// Sets an image parameter of the filter (stored in a buffer). +// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. +OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name, + OIDNBuffer buffer, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride); + +// Sets an image parameter of the filter (owned by the user). +// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. +OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name, + void* ptr, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride); + +// Sets a boolean parameter of the filter. +OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value); + +// Gets a boolean parameter of the filter. +OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name); + +// Sets an integer parameter of the filter. +OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value); + +// Gets an integer parameter of the filter. +OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name); + +// Sets a float parameter of the filter. +OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value); + +// Gets a float parameter of the filter. +OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name); + +// Sets the progress monitor callback function of the filter. +OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr); + +// Commits all previous changes to the filter. +// Must be called before first executing the filter. +OIDN_API void oidnCommitFilter(OIDNFilter filter); + +// Executes the filter. +OIDN_API void oidnExecuteFilter(OIDNFilter filter); + +#if defined(__cplusplus) +} +#endif diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp new file mode 100644 index 0000000000..9f95a56fe1 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp @@ -0,0 +1,468 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include +#include "oidn.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // Buffer + // -------------------------------------------------------------------------- + + // Formats for images and other data stored in buffers + enum class Format + { + Undefined = OIDN_FORMAT_UNDEFINED, + + // 32-bit single-precision floating point scalar and vector formats + Float = OIDN_FORMAT_FLOAT, + Float2 = OIDN_FORMAT_FLOAT2, + Float3 = OIDN_FORMAT_FLOAT3, + Float4 = OIDN_FORMAT_FLOAT4, + }; + + // Access modes for mapping buffers + enum class Access + { + Read = OIDN_ACCESS_READ, // read-only access + Write = OIDN_ACCESS_WRITE, // write-only access + ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access + WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded + }; + + // Buffer object with automatic reference counting + class BufferRef + { + private: + OIDNBuffer handle; + + public: + BufferRef() : handle(nullptr) {} + BufferRef(OIDNBuffer handle) : handle(handle) {} + + BufferRef(const BufferRef& other) : handle(other.handle) + { + if (handle) + oidnRetainBuffer(handle); + } + + BufferRef(BufferRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + BufferRef& operator =(const BufferRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainBuffer(other.handle); + if (handle) + oidnReleaseBuffer(handle); + handle = other.handle; + } + return *this; + } + + BufferRef& operator =(BufferRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + BufferRef& operator =(OIDNBuffer other) + { + if (other) + oidnRetainBuffer(other); + if (handle) + oidnReleaseBuffer(handle); + handle = other; + return *this; + } + + ~BufferRef() + { + if (handle) + oidnReleaseBuffer(handle); + } + + OIDNBuffer getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Maps a region of the buffer to host memory. + // If byteSize is 0, the maximum available amount of memory will be mapped. + void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0) + { + return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize); + } + + // Unmaps a region of the buffer. + // mappedPtr must be a pointer returned by a previous call to map. + void unmap(void* mappedPtr) + { + oidnUnmapBuffer(handle, mappedPtr); + } + }; + + // -------------------------------------------------------------------------- + // Filter + // -------------------------------------------------------------------------- + + // Progress monitor callback function + typedef bool (*ProgressMonitorFunction)(void* userPtr, double n); + + // Filter object with automatic reference counting + class FilterRef + { + private: + OIDNFilter handle; + + public: + FilterRef() : handle(nullptr) {} + FilterRef(OIDNFilter handle) : handle(handle) {} + + FilterRef(const FilterRef& other) : handle(other.handle) + { + if (handle) + oidnRetainFilter(handle); + } + + FilterRef(FilterRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + FilterRef& operator =(const FilterRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainFilter(other.handle); + if (handle) + oidnReleaseFilter(handle); + handle = other.handle; + } + return *this; + } + + FilterRef& operator =(FilterRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + FilterRef& operator =(OIDNFilter other) + { + if (other) + oidnRetainFilter(other); + if (handle) + oidnReleaseFilter(handle); + handle = other; + return *this; + } + + ~FilterRef() + { + if (handle) + oidnReleaseFilter(handle); + } + + OIDNFilter getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Sets an image parameter of the filter (stored in a buffer). + void setImage(const char* name, + const BufferRef& buffer, Format format, + size_t width, size_t height, + size_t byteOffset = 0, + size_t bytePixelStride = 0, size_t byteRowStride = 0) + { + oidnSetFilterImage(handle, name, + buffer.getHandle(), (OIDNFormat)format, + width, height, + byteOffset, + bytePixelStride, byteRowStride); + } + + // Sets an image parameter of the filter (owned by the user). + void setImage(const char* name, + void* ptr, Format format, + size_t width, size_t height, + size_t byteOffset = 0, + size_t bytePixelStride = 0, size_t byteRowStride = 0) + { + oidnSetSharedFilterImage(handle, name, + ptr, (OIDNFormat)format, + width, height, + byteOffset, + bytePixelStride, byteRowStride); + } + + // Sets a boolean parameter of the filter. + void set(const char* name, bool value) + { + oidnSetFilter1b(handle, name, value); + } + + // Sets an integer parameter of the filter. + void set(const char* name, int value) + { + oidnSetFilter1i(handle, name, value); + } + + // Sets a float parameter of the filter. + void set(const char* name, float value) + { + oidnSetFilter1f(handle, name, value); + } + + // Gets a parameter of the filter. + template + T get(const char* name); + + // Sets the progress monitor callback function of the filter. + void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr) + { + oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr); + } + + // Commits all previous changes to the filter. + void commit() + { + oidnCommitFilter(handle); + } + + // Executes the filter. + void execute() + { + oidnExecuteFilter(handle); + } + }; + + // Gets a boolean parameter of the filter. + template<> + inline bool FilterRef::get(const char* name) + { + return oidnGetFilter1b(handle, name); + } + + // Gets an integer parameter of the filter. + template<> + inline int FilterRef::get(const char* name) + { + return oidnGetFilter1i(handle, name); + } + + // Gets a float parameter of the filter. + template<> + inline float FilterRef::get(const char* name) + { + return oidnGetFilter1f(handle, name); + } + + // -------------------------------------------------------------------------- + // Device + // -------------------------------------------------------------------------- + + // Device types + enum class DeviceType + { + Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically + + CPU = OIDN_DEVICE_TYPE_CPU, // CPU device + }; + + // Error codes + enum class Error + { + None = OIDN_ERROR_NONE, // no error occurred + Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred + InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified + InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed + OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation + UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported + Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user + }; + + // Error callback function + typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message); + + // Device object with automatic reference counting + class DeviceRef + { + private: + OIDNDevice handle; + + public: + DeviceRef() : handle(nullptr) {} + DeviceRef(OIDNDevice handle) : handle(handle) {} + + DeviceRef(const DeviceRef& other) : handle(other.handle) + { + if (handle) + oidnRetainDevice(handle); + } + + DeviceRef(DeviceRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + DeviceRef& operator =(const DeviceRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainDevice(other.handle); + if (handle) + oidnReleaseDevice(handle); + handle = other.handle; + } + return *this; + } + + DeviceRef& operator =(DeviceRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + DeviceRef& operator =(OIDNDevice other) + { + if (other) + oidnRetainDevice(other); + if (handle) + oidnReleaseDevice(handle); + handle = other; + return *this; + } + + ~DeviceRef() + { + if (handle) + oidnReleaseDevice(handle); + } + + OIDNDevice getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Sets a boolean parameter of the device. + void set(const char* name, bool value) + { + oidnSetDevice1b(handle, name, value); + } + + // Sets an integer parameter of the device. + void set(const char* name, int value) + { + oidnSetDevice1i(handle, name, value); + } + + // Gets a parameter of the device. + template + T get(const char* name); + + // Sets the error callback function of the device. + void setErrorFunction(ErrorFunction func, void* userPtr = nullptr) + { + oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr); + } + + // Returns the first unqueried error code and clears the stored error. + // Can be called for a null device as well to check why a device creation failed. + Error getError() + { + return (Error)oidnGetDeviceError(handle, nullptr); + } + + // Returns the first unqueried error code and string message, and clears the stored error. + // Can be called for a null device as well to check why a device creation failed. + Error getError(const char*& outMessage) + { + return (Error)oidnGetDeviceError(handle, &outMessage); + } + + // Commits all previous changes to the device. + // Must be called before first using the device (e.g. creating filters). + void commit() + { + oidnCommitDevice(handle); + } + + // Creates a new buffer (data allocated and owned by the device). + BufferRef newBuffer(size_t byteSize) + { + return oidnNewBuffer(handle, byteSize); + } + + // Creates a new shared buffer (data allocated and owned by the user). + BufferRef newBuffer(void* ptr, size_t byteSize) + { + return oidnNewSharedBuffer(handle, ptr, byteSize); + } + + // Creates a new filter of the specified type (e.g. "RT"). + FilterRef newFilter(const char* type) + { + return oidnNewFilter(handle, type); + } + }; + + // Gets a boolean parameter of the device. + template<> + inline bool DeviceRef::get(const char* name) + { + return oidnGetDevice1b(handle, name); + } + + // Gets an integer parameter of the device (e.g. "version"). + template<> + inline int DeviceRef::get(const char* name) + { + return oidnGetDevice1i(handle, name); + } + + // Creates a new device. + inline DeviceRef newDevice(DeviceType type = DeviceType::Default) + { + return DeviceRef(oidnNewDevice((OIDNDeviceType)type)); + } + +} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/version.h b/thirdparty/oidn/include/OpenImageDenoise/version.h new file mode 100644 index 0000000000..66b347c992 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/version.h @@ -0,0 +1,23 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#define OIDN_VERSION_MAJOR 1 +#define OIDN_VERSION_MINOR 1 +#define OIDN_VERSION_PATCH 0 +#define OIDN_VERSION 10100 +#define OIDN_VERSION_STRING "1.1.0" diff --git a/thirdparty/oidn/mkl-dnn/LICENSE b/thirdparty/oidn/mkl-dnn/LICENSE new file mode 100644 index 0000000000..d13f7b7ca0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/LICENSE @@ -0,0 +1,214 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ============================================================================ + + Intel MKL-DNN includes components with separate copyright + notices and license terms. + + XByak, 3-clause BSD license + Copyright (c) 2007 MITSUNARI Shigeo + See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT + + gtest, 3-clause BSD license + Copyright 2008, Google Inc. + See full copyright notice and license text in tests/gtests/gtest/LICENSE diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.h b/thirdparty/oidn/mkl-dnn/include/mkldnn.h new file mode 100644 index 0000000000..9b64994922 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.h @@ -0,0 +1,1771 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_H +#define MKLDNN_H + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +/* All symbols shall be internal unless marked as MKLDNN_API */ +#if defined _WIN32 || defined __CYGWIN__ +# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) +# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) +#else +# if __GNUC__ >= 4 +# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) +# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) +# else +# define MKLDNN_HELPER_DLL_IMPORT +# define MKLDNN_HELPER_DLL_EXPORT +# endif +#endif + +#ifdef MKLDNN_DLL +# ifdef MKLDNN_DLL_EXPORTS +# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT +# else +# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT +# endif +#else +# define MKLDNN_API +#endif + +#if defined (__GNUC__) +# define MKLDNN_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +# define MKLDNN_DEPRECATED __declspec(deprecated) +#else +# define MKLDNN_DEPRECATED +#endif + +#include "mkldnn_types.h" +#include "mkldnn_version.h" +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +#ifdef __cplusplus +extern "C" { +#endif + +/** @addtogroup c_api C API + * @{ */ + +/** @addtogroup c_api_primitive Primitive operations + * @{ */ + +/** @addtogroup c_api_primitive_common Common primitive operations + * @{ */ + +/** Creates a primitive descriptor @p iterator for given @p op_desc, @p attr, + * @p engine, and optionally a hint primitive descriptor from forward + * propagation (required for backward propagation). Pass @c NULL for forward + * propagation. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create( + mkldnn_primitive_desc_iterator_t *iterator, + const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine, + const_mkldnn_primitive_desc_t hint_forward_primitive_desc); + +/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no + * more primitive descriptors are available. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next( + mkldnn_primitive_desc_iterator_t iterator); + +/** Fetches the current primitive descriptor. + * + * @note + * The user should delete the fetched primitive descriptor using + * mkldnn_primitive_desc_destroy() once it is no longer needed. */ +mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch( + const_mkldnn_primitive_desc_iterator_t iterator); + +/** Deletes a primitive descriptor @p iterator */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy( + mkldnn_primitive_desc_iterator_t iterator); + +/** Creates a @p primitive_desc using @p op_desc, @p attr, @p engine, and + * optionally a hint primitive descriptor from forward propagation. The call is + * equivalent to creating a primitive descriptor iterator, immediately fetching + * a primitive descriptor, and then destroying the iterator. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create( + mkldnn_primitive_desc_t *primitive_desc, + const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine, + const_mkldnn_primitive_desc_t hint_forward_primitive_desc); + +/** Makes a copy of a @p primitive_desc. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone( + mkldnn_primitive_desc_t *primitive_desc, + const_mkldnn_primitive_desc_t existing_primitive_desc); + +/** Returns a constant reference to the attribute of a @p primitive_desc. + * + * @warning + * The user should not destroy the obtained @p attr. + * + * @warning + * The lifetime of an @p attr is the same as that of a @p primitive_desc, + * so it is illegal to use the @p attr once @p primitive_desc has been + * destroyed. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr( + const_mkldnn_primitive_desc_t primitive_desc, + const_mkldnn_primitive_attr_t *attr); + +/** Deletes a @p primitive_desc. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy( + mkldnn_primitive_desc_t primitive_desc); + +/** Queries primitive descriptor + * + * One of the most typical use cases is to query a convolution primitive + * descriptor created with source, weights, and destination formats equal + * to #mkldnn_format_tag_any about the corresponding memory descriptors + * (@p what equals #mkldnn_query_src_md, #mkldnn_query_weights_md, and + * #mkldnn_query_dst_md respectively) to be able to prepare memory and + * create reorders if required. + * + * Another quite typical use case is to query an operation primitive + * descriptor for a workspace (@p what equals #mkldnn_query_workspace_md). + * The returned status #mkldnn_not_required indicates that a workspace is + * not required. + * + * A few other possibilities: + * - query an operation primitive descriptor for the underlying operation + * descriptor (#mkldnn_query_convolution_d, #mkldnn_query_eltwise_d, + * #mkldnn_query_rnn_d, etc.) + * - query an operation primitive descriptor for the implementation + * information string (#mkldnn_query_impl_info_str) + * - query an operation primitive descriptor for the number of inputs and + * outputs (#mkldnn_query_num_of_inputs_s32 and + * #mkldnn_query_num_of_outputs_s32 respectively) + * + * @sa mkldnn_query_t for more options + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index, void *result); + +/** Queries primitive descriptor for memory descriptor + * + * @returns NULL in case of any error. + * + * This is just a specialized version of mkldnn_primitive_desc_query + * used for convenience. + */ +const mkldnn_memory_desc_t MKLDNN_API *mkldnn_primitive_desc_query_md( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index); + +/** Queries primitive descriptor for signed 32bit int + * + * @returns 0 in case of any error (in particular if the queried entity is + * not of type int32_t). Note that 0 might also be the actual returned + * value. + * + * This is just a specialized version of mkldnn_primitive_desc_query + * used for convenience. + */ +int MKLDNN_API mkldnn_primitive_desc_query_s32( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index); + +/** Creates a @p primitive using a @p primitive_desc descriptor. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_create( + mkldnn_primitive_t *primitive, + const_mkldnn_primitive_desc_t primitive_desc); + +/** Executes a @p primitive using a @p stream, and @p nargs arguments + * @p args. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_execute( + const_mkldnn_primitive_t primitive, mkldnn_stream_t stream, + int nargs, const mkldnn_exec_arg_t *args); + +/** Retrieves a reference to the @p primitive_desc descriptor of given @p + * primitive. + * + * @warning + * The returned object must not be destroyed by the user. The @c const + * qualifier of the returned object prevents such attempts. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc( + const_mkldnn_primitive_t primitive, + const_mkldnn_primitive_desc_t *primitive_desc); + +/** Deletes a @p primitive. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy( + mkldnn_primitive_t primitive); + +/** @} */ + +/** @addtogroup c_api_attributes Attributes + * An extension for controlling primitive behavior. + * @{ */ + +/** Creates an empty (default) @p attr attribute. All the parameters are set to + * default values. + * + * An empty attribute is used in primitive descriptor creation whenever it + * is not passed explicitly, e.g. in mkldnn_primitive_desc_create. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create( + mkldnn_primitive_attr_t *attr); + +/** Makes a copy of an @p existing_attr. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone( + mkldnn_primitive_attr_t *attr, + const_mkldnn_primitive_attr_t existing_attr); + +/** Deletes an @p attr. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy( + mkldnn_primitive_attr_t attr); + +/** Returns the scratchpad @p mode set in the attribute @p attr */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_scratchpad_mode( + const_mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t *mode); + +/** Sets scratchpad @p mode. + * + * The possible values are: #mkldnn_scratchpad_mode_library (default) and + * #mkldnn_scratchpad_mode_user. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_scratchpad_mode( + mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t mode); + +/** Returns @p count, correspondence scale @p mask, and a pointer to a constant + * floating point array of output @p scales for given @p attr, previously set + * by mkldnn_primitive_attr_set_output_scales. + * + * @warning + * The @p scales array points to the internal @p attr field, so the user + * should not modify or destroy @p scales. + * + * @warning + * The lifetime of @p scales is the same as that of the @p attr to which it + * belongs, so it is illegal to use @p scales after @p attr is destroyed. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales( + const_mkldnn_primitive_attr_t attr, mkldnn_dim_t *count, int *mask, + const float **scales); + +/** Sets output @p scales for primitive operations. The number of elements @p + * count and correspondence scale @p mask are stored for future use. + * + * The @p mask argument defines the correspondence between the output tensor + * dimensions and the @p scales array. Set the i-th bit of @p mask to 1 to use a + * dedicated scaling factor for each slice of the output tensor over the i-th + * dimension. Set @p mask to 0 to use a common scaling factor for the whole + * output tensor. + * + * @note + * The dimension order is always native and does not depend on the actual + * layout used. Examples: + * - 2D dimensional data the order of dimensions is always: (n, c) + * - 4D dimensional data the order is always: (n, c, h, w) + * - 5D dimensional weights the order is always: (g, oc, ic, kh, kw) + * + * Example usage: + * @code + * int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params + * float scales[oc] = { ... }; // unique output scales per output channel + * int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ... + * + * mkldnn_convolution_desc_t cd; // create & configure convolution op_desc + * + * mkldnn_primitive_attr_t attr; + * mkldnn_primitive_attr_create(&attr); // create default attributes + * mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales); + * + * mkldnn_primitive_desc_t cpd; + * mkldnn_primitive_desc_create(&cpd, &cd, attr, NULL); + * @endcode + * + * @note + * There is no way to check that @p count corresponds to @p mask until an + * actual primitive descriptor is created, so it is the user's + * responsibility to set proper values. The following formula must hold: + * + * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales( + mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, + const float *scales); + +/** Returns @p post_ops for given @p attr. + * + * @warning + * @p post_ops points to the internal @p attr field, so the user should not + * modify or destroy @p post_ops. Also, the lifetime of @p post_ops is the + * same as that of the @p attr it belongs to, so it is illegal to use @p + * post_ops after @p attr has been destroyed. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops( + const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops); + +/** Sets configured @p post_ops to an attribute @p attr for future use (when + * primitive descriptor is being created). + * + * @note + * At this point in time, there is no way to check whether the primitive + * descriptor does or does not support a given sequence of post operations. + * Therefore the user should handle an error that might occur at the + * mkldnn_primitive_desc_create call. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops( + mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops); + +/** @addtogroup c_api_attributes_post_ops Sequence of post operations + * An extension for performing extra operations after a base operation. + * @{ */ + +/** Creates an empty sequence of post operations @p post_ops. */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops); + +/** Deletes a @p post_ops sequence. */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops); + +/** Returns the @p length of post operations for given @p post_ops. */ +int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops); + +/** Returns the type of post operation with index @p index in given + * @p post_ops. In case of error, returns #mkldnn_undefined_primitive. */ +mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind( + const_mkldnn_post_ops_t post_ops, int index); + +/** Appends accumulation (sum) post operation to the @p post_ops. Prior to + * accumulating the result, the previous value would be multiplied by @p scale. + * + * The kind of this post operation is #mkldnn_sum. + * + * This feature might improve performance for cases like residual learning + * blocks, where the result of convolution is accumulated to the previously + * computed activations. The parameter @p scale might be extreme for the + * integer-based computations when the result and previous activations have + * different logical scaling factors. + * + * In the simplest case when the accumulation is the only post operation, the + * computations would be: + * dst[] <- scale * dst[] + op(...) // instead of dst[] <- op(...) + * + * @note + * This post operation (as well as all the others) disregards the original + * layout of the destination; that is, the layout of the original + * destination is expected to be the same as the layout of the stored + * destination. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum( + mkldnn_post_ops_t post_ops, float scale); + +/** Gets the parameters of the accumulation (sum) post operation with index + * @p index in the sequence of @p post_ops. + * + * @note + * If index @p index would not correspond to the accumulation post + * operation, the function returns #mkldnn_invalid_arguments. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum( + const_mkldnn_post_ops_t post_ops, int index, float *scale); + +/** Appends eltwise post operation to the @p post_ops with given parameters + * @p kind, @p alpha, and @p beta (@sa mkldnn_eltwise_forward_desc_init and + * mkldnn_eltwise_desc_t). + * + * The kind of this post operation is #mkldnn_eltwise. + * + * In the simplest case when the eltwise is the only post operation, the + * computations would be: + * dst[] <- scale * eltwise_op ( op(...) ) // instead of dst[] <- op(...) + * where eltwise_op is configured with the given parameters. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise( + mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, + float alpha, float beta); + +/** Gets the eltwise parameters of the post operation with index @p index in + * the sequence of @p post_ops. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise( + const_mkldnn_post_ops_t post_ops, int index, float *scale, + mkldnn_alg_kind_t *alg, float *alpha, float *beta); + +/** @} */ + +/** @} */ + +/** @addtogroup c_api_memory Memory + * A primitive to describe and store data. + * + * The library supports various data types and formats. Memory hierarchy + * consists of three levels of abstraction: + * 1. **Memory descriptor** -- engine agnostic logical description of data + * (number of dimensions, dimensions themselves, and data type), and + * optionally the format/layout that describes the physical representation + * of data in memory. If the format is not known yet, one can pass + * #mkldnn_format_tag_any. This approach is used to allow compute-intensive + * primitives to specify the most appropriate format on their own with + * users required to reorder the data if the incoming format doesn't match + * the primitive's selection. Memory descriptor can be initialized with + * mkldnn_memory_desc_init_by_tag() or mkldnn_memory_desc_init_by_strides() + * functions, or by directly filling the mkldnn_memory_desc_t structure. + * The latter requires deep knowledge of how the physical data + * representation is mapped to the structure. + * The @ref understanding_memory_formats topic should shed some light on + * that. + * For the fully defined memory descriptors (i.e. where the format kind is + * not equal to #mkldnn_format_kind_any) a user can the size, using the + * mkldnn_memory_desc_get_size() function. As described in + * @ref understanding_memory_formats, the size of data sometimes cannot + * be computed as the product of dimensions times the size of the data + * type. So users are encouraged to use this function for better code + * portability. + * Two memory descriptors can be compared with mkldnn_memory_desc_equal(). + * The comparison is especially useful when checking whether a primitive + * requires reorder from the user's data format to the primitive's format. + * 2. **Memory** -- an engine-specific object that handles the data and its + * description (a memory descriptor). For CPU enigne, the data handle is + * simply a pointer to @c void. The data handle can be queried using + * mkldnn_memory_get_data_handle() and set using + * mkldnn_memory_set_data_handle(). The latter function always sets the + * memory in the padding region to zero, which is the invariant maintained + * by all the primitives in Intel MKL-DNN. + * See @ref understanding_memory_formats for more details. + * A memory can be created using mkldnn_memory_create() function. + * A memory can also be queried for the underlying memory descriptor and + * engine using mkldnn_memory_get_memory_desc() and + * mkldnn_memory_get_engine() functions. + * + * Along with ordinary memory with all dimensions being positive, Intel + * MKL-DNN supports *zero-volume* memory with one or more dimensions set to + * zero. This is to support the NumPy\* convention. + * If a *zero-volume* memory is passed to a primitive, the primitive does + * not perform any computations on this memory. For example: + * - Convolution with `(0 batch, 3 input channels, 13 height, 13 width)` + * source and `(16 output channels, 3 inputs, channel, 3 height, 3 width)` + * weights would produce `(0 batch, 16 output channels, 11 height, 11 width)` + * destination (assuming strides are `1` and paddings are zero) and perform + * zero multiply-add operations. + * - Concatenation of three memories of shapes `(3, 4, 13, 13)`, + * `(3, 0, 13, 13)`, and `(3, 1, 13, 13)` along the second axis would produce + * the output of the shape `(3, 5, 13, 13)`, effectively ignoring the second + * input (however, if the user created a concatenation primitive descriptor + * with three inputs they should also provide all three memories to the + * concatenation primitive, including the one with zero second dimension). + * - However, Intel MKL-DNN would return an error when attempting to create a + * convolution with *zero-volume* memory passed for weights because such a + * convolution is not well-defined: + * ~~~ + * dst(1, 16, 11, 11) <-- src(1, 0, 13, 13) (*) wei(16, 0, 3, 3) + * ~~~ + * Should the values in the destination be zeroes or just not accessed at + * all? Moreover, backward pass w.r.t. weights in such cases is also not + * well-defined. + * + * Data handle of *zero-volume* memory is never accessed and hence can be + * unset (NULL in case of CPU engine). + * + * @sa @ref understanding_memory_formats + * @{ */ + +/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p + * data_type, and @p strides. + * + * The @p strides might be NULL, which means the order of physical dimensions + * is the same as the order of logical ones. + * + * @note The logical order of dimensions is defined by a primitive that + * consumes the memory. + */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_strides( + mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, + mkldnn_data_type_t data_type, const mkldnn_dims_t strides); + +/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p + * data_type, and format @p tag. + * + * @p tag can be #mkldnn_format_tag_any, which allows a primitive to define + * the appropriate memory format. In this case, the @p format_kind would be set + * to #mkldnn_format_kind_any */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_tag( + mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, + mkldnn_data_type_t data_type, mkldnn_format_tag_t tag); + +/** Initializes a @p memory_desc for a given @p parent_memory_desc, with + * @p dims sizes and @p offsets. May fail if layout used does not allow + * obtain desired submemory. In this case consider using `extract` or `insert` + * primitive */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_submemory( + mkldnn_memory_desc_t *memory_desc, + const mkldnn_memory_desc_t *parent_memory_desc, + const mkldnn_dims_t dims, const mkldnn_dims_t offsets); + +/** Compares two memory descriptors. + * @return 1 if the descriptors are the same. + * @return 0 if the descriptors are different. + * + * Use this function to identify whether a reorder is required between the + * two memories */ +int MKLDNN_API mkldnn_memory_desc_equal( + const mkldnn_memory_desc_t *lhs, + const mkldnn_memory_desc_t *rhs); + +/** Returns the size (in bytes) that is required for given @p memory_desc */ +size_t MKLDNN_API mkldnn_memory_desc_get_size( + const mkldnn_memory_desc_t *memory_desc); + +/** Creates a memory for given @p memory_desc and @p engine. Also sets handle + * to @p native_handle. + * The @p native_handle can: + * - point to the user allocated memory, i.e. valid handle. In this case the + * library doesn't own allocated memory. + * - be MKLDNN_NATIVE_HANDLE_ALLOCATE to ask the library to allocate and + * attach memory. In this case the library owns allocated memory. + * - be MKLDNN_NATIVE_HANDLE_NONE to create mkldnn_memory w/o attached memory. + */ +mkldnn_status_t MKLDNN_API mkldnn_memory_create(mkldnn_memory_t *memory, + const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine, + void *native_handle); + +/** Returns a @p memory_desc associated with @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_memory_desc( + const_mkldnn_memory_t memory, + const mkldnn_memory_desc_t **memory_desc); + +/** Returns an @p engine associated with @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_engine( + const_mkldnn_memory_t memory, mkldnn_engine_t *engine); + +/** For a @p memory, returns the data @p handle. + * + * For the CPU engine, the data handle is a pointer to the actual data. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle( + const_mkldnn_memory_t memory, void **handle); + +/** For a @p memory, sets the data @p handle. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle( + mkldnn_memory_t memory, void *handle); + +/** Deletes a @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_destroy(mkldnn_memory_t memory); + +/** @} */ + +/** @addtogroup c_api_reorder Reorder + * A primitive to copy data between memory formats. + * @{ */ + +/** Initializes a @p reorder_primitive_desc using the description of the source + * (@p src_engine and @p src_md) and destination (@p dst_engine and @p dst_md) + * memory, and an @p attr attribute. + * + * Inputs: + * - input (#mkldnn_query_src_md, 0) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create( + mkldnn_primitive_desc_t *reorder_primitive_desc, + mkldnn_engine_t src_engine, const mkldnn_memory_desc_t *src_md, + mkldnn_engine_t dst_engine, const mkldnn_memory_desc_t *dst_md, + const_mkldnn_primitive_attr_t attr); + +/** @} */ + +/** @addtogroup c_api_concat Concat + * A primitive to concatenate data by arbitrary dimension. + * @{ */ + +/** Creates out-of-place @p concat_primitive_desc for concatenation of @p n + * inputs by @p concat_dimension with resulting @p output_desc memory + * descriptor. @p output_desc can be NULL or specified with the + * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory + * format would be chosen automatically. + * + * Inputs: + * - input 0 (#mkldnn_query_src_md, 0) + * - input 1 (#mkldnn_query_src_md, 1) + * - ... + * - input @p n - 1 (#mkldnn_query_src_md, @p n - 1) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create( + mkldnn_primitive_desc_t *concat_primitive_desc, + const mkldnn_memory_desc_t *dst_md, + int n, int concat_dimension, + const mkldnn_memory_desc_t *src_mds, + const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_sum Sum + * A primitive to sum data. + * @{ */ + +/** Creates out-of-place @p sum_primitive_desc for sum of @p n + * inputs multiplied by scale with resulting @p output_desc memory + * descriptor. @p output_desc can be NULL or specified with the + * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory + * format would be chosen automatically. + * + * Inputs: + * - src 0 (#mkldnn_query_src_md, 0) + * - src 1 (#mkldnn_query_src_md, 1) + * - ... + * - src @p n - 1 (#mkldnn_query_src_md, @p n - 1) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create( + mkldnn_primitive_desc_t *sum_primitive_desc, + const mkldnn_memory_desc_t *dst_mds, + int n, const float *scales, + const mkldnn_memory_desc_t *src_mds, + const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_convolution Convolution + * A primitive to compute convolution using different algorithms. + * + * \f[dst[n][oc][oh][ow] = + * \sum_{kw=0}^{KW}\sum_{kh=0}^{KH}\sum_{ic=0}^{IC} + * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw] + * \cdot weights[g][oc][ic][kh][kw] + * + bias[g][oc],\f] + * + * where size of output spatial domain is given by + * \f$ OH = \left\lfloor{\frac{IH - KH + p_l[0] + p_r[0]}{s_h}} + * \right\rfloor + 1\f$, + * \f$ OW = \left\lfloor{\frac{IW - KW + p_l[1] + p_r[1]}{s_w}} + * \right\rfloor + 1\f$, + * + * and summation is carried over input channels \f$ic\f$ in + * group \f$g\f$, and \f$s_h, s_w\f$ are @p strides and + * \f$p_l, p_r\f$ are @p padding_l and @p padding_r. + * @{ */ + +/** Initializes a convolution descriptor @p conv_desc for forward propagation + * using @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. In order to create a + * convolution without bias, @p bias_desc should either be @c NULL or point to + * a descriptor with memory format kind equal to #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated convolution descriptor @p conv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * In order to create a dilated convolution without bias, @p bias_desc + * should either be @c NULL or point to a descriptor with memory format kind + * equals #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to data using @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated convolution descriptor @p conv_desc for backward + * propagation with respect to data using @p alg_kind, memory descriptors, @p + * strides, @p dilates @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p dilates @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API +mkldnn_dilated_convolution_backward_weights_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_deconvolution Deconvolution + * A primitive to compute deconvolution using different algorithms. + * + * @{ */ + + +/** Initializes a deconvolution descriptor @p deconv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. In order to create a + * deconvolution without bias, @p bias_desc should either be @c NULL or point to + * a descriptor with memory format kind equals #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p deconv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. In order to + * create a dilated deconvolution without bias, @p bias_desc should either be + * @c NULL or point to a descriptor with memory format kind equal + * #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a deconvolution descriptor @p conv_desc for backward propagation + * with respect to data using @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p conv_desc for backward + * propagation with respect to data using @p alg_kind, memory descriptors, @p + * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a deconvolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p conv_desc for backward + * propagation with respect to weights using @p alg_kind, memory descriptors, + * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_shuffle Shuffle + * A primitive to shuffle data along the axis. + * @{ */ + +/** Initializes a @p shuffle_desc for forward propagation using @p prop_kind, + * memory descriptor @p data_desc, @p axis, and @p group_size. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * + */ +mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init( + mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *data_desc, int axis, + mkldnn_dim_t group_size); + +/** Initializes a @p shuffle_desc for backward propagation using memory + * descriptor @p diff_data_desc, @p axis, and @p group_size. + * + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + * + */ +mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init( + mkldnn_shuffle_desc_t *shuffle_desc, + const mkldnn_memory_desc_t *diff_data_desc, int axis, + mkldnn_dim_t group_size); + +/** @} */ + +/** @addtogroup c_api_eltwise Eltwise + * A primitive to compute element-wise operations like parametric rectifier + * linear unit (ReLU). + * + * Both forward and backward passes support in-place operation; that is, src + * and dst point to the same memory for forward pass, and diff_dst and diff_src + * point to the same memory for backward pass. + * + * @warning Because the original src is required for backward pass, in-place + * forward pass in general cannot be applied during training. However, for some + * kinds of element-wise operations (namely ReLU with alpha parameter equals 0), + * dst and src can be interchangeable for the backward pass, which enables + * performing in-place forward even for training. + * + * @{ */ + +/** Initializes an @p eltwise_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), + * @p alg_kind algorithm, memory descriptor @p data_desc, @p alpha, and + * @p beta parameters. + * + * @sa mkldnn_eltwise_desc_t for details. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init( + mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, + float alpha, float beta); + +/** Initializes an @p eltwise_desc for backward propagation using @p alg_kind + * algorithm memory descriptors @p diff_data_desc and @p data_desc, and the + * @p alpha and @p beta parameters. + * + * @sa mkldnn_eltwise_desc_t for details. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init( + mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, float alpha, float beta); + +/** @} */ + +/** @addtogroup c_api_softmax Softmax + * A primitive to perform softmax. + * + * \f[dst[u][c][in] = + * \frac{\exp(src[ou][c][in]) - \max\limits_{c}(src[ou][c][in])} + * {\sum\limits_{c}\{\exp(src[ou][c][in]) + * - \max\limits_{c}(src[ou][c][in])\}},\f] + * + * where \f$ou, iu\f$ are outer and inner sizes repectively, defined + * by @p data_desc.dims and @p softmax_axis. + * @{ */ + +/** Initializes a @p softmax_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference) + * and memory descriptor @p data_desc. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init( + mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *data_desc, int softmax_axis); + +/** Initializes a @p softmax_desc for backward propagation using memory + * descriptors @p diff_desc and @p data_desc. + * + * Inputs: + * - dst (#mkldnn_query_dst_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init( + mkldnn_softmax_desc_t *softmax_desc, + const mkldnn_memory_desc_t *diff_desc, + const mkldnn_memory_desc_t *data_desc, int softmax_axis); + +/** @} */ + +/** @addtogroup c_api_pooling Pooling + * A primitive to perform max or average pooling. + * + * Max pooling: + * \f[dst[n][oc][oh][ow] = + * \max\limits_{kw,kh} + * (src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw]),\f] + * + * Average pooling: + * \f[dst[n][oc][oh][ow] = + * \frac{1}{KW \cdot KH}\sum\limits_{kw,kh} + * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw],\f] + * + * where \f$p_l, p_r\f$ are @p padding_l and @p padding_r respectively, and + * output spatial dimensions are calculated similarly to how they are done in + * convolution. + * + * During training, max pooling requires a workspace on forward + * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes to + * save indices where maximum was found. The workspace layout is opaque, and + * the indices cannot be restored from it. However, one can use backward + * pooling to perform up-sampling (used in some detection topologies). + * + * @{ */ + +/** Initializes a pooling descriptor @p pool_desc for forward propagation using + * @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference), @p alg_kind, memory descriptors, and pooling + * parameters in the spatial domain: @p strides, @p kernel sizes, @p padding_l, + * @p padding_r, and @p padding_kind. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p alg_kind = #mkldnn_pooling_max and + * @p prop_kind = #mkldnn_forward_training + */ +mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init( + mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a pooling descriptor @p pool_desc for backward propagation + * using @p alg_kind, memory descriptors, and pooling parameters in the spatial + * domain: @p strides, @p kernel sizes, @p padding_l, @p padding_r, and @p + * padding_kind. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p alg_kind = #mkldnn_pooling_max + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init( + mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_lrn LRN + * A primitive to perform local response normalization (LRN) across or within + * channels. + * + * LRN accross channels: + * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} + * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} + * (src[n][c+i][h][w])^2\right\}^{-\beta} + * src[n][c][h][w],\f] + * + * LRN within channels: + * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} + * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} + * (src[n][c][h+i][w+i])^2\right\}^{-\beta} + * src[n][c][h][w],\f] + * + * where \f$n_{l}\f$ is the @p local_size. + * + * During training, LRN might or might not require a workspace on forward + * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes. The + * behavior is implementation specific. Optimized implementations typically + * require a workspace and use it to save some intermediate results from the + * forward pass that accelerate computations on the backward pass. + * + * To check whether a workspace is required, query the LRN primitive descriptor + * for the workspace (#mkldnn_query_workspace_md). Success indicates that the + * workspace is required and its description will be returned. + * @sa mkldnn_primitive_desc_query and mkldnn_primitive_desc_query_pd + * + * @{ */ + +/** Initializes an @p lrn_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), + * @p alg_kind, memory descriptor @p data_desc, and regularization + * parameters @p local_size, @p alpha, @p beta, and @p k. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if the underlying implementation requires + */ +mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init( + mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, + mkldnn_dim_t local_size, float alpha, float beta, float k); + +/** Initializes an @p lrn_desc for backward propagation using @p alg_kind, + * memory descriptors @p data_desc and @p diff_data_desc, and regularization + * parameters @p local_size, @p alpha, @p beta, and @p k. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if the underlying implementation requires + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init( + mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, mkldnn_dim_t local_size, + float alpha, float beta, float k); + +/** @} */ + +/** @addtogroup c_api_batch_normalization Batch Normalization + * A primitive to perform batch normalization. + * + * \f[dst[n][c][h][w] = \gamma[c] \frac{src[n][c][h][w] - \mu[c]} + * {\sqrt{\sigma[c] + eps}} + \beta[c],\f] + * + * where \f$\gamma[c], \beta[c]\f$ are weights and bias for a channel and, + * + * \f$\mu[c] = \frac{1}{NHW} \sum\limits_{whn} src[n][c][h][w]\f$, + * \f$\sigma[c] = \frac{1}{NHW} \sum\limits_{whn} + * (src[n][c][h][w] - \mu[c])^2\f$, + * + * and @c eps is a constant to improve numerical stability. + * + * Both forward and backward passes support in-place operation; that is, src + * and dst point to the same memory for forward pass, and diff_dst and diff_src + * point to the same memory for backward pass. + * + * Batch normalization supports different flavors controlled by + * mkldnn_batch_normalization_desc_t. For example, batch normalization can + * compute the mean and variance on its own or take them as inputs. It can + * either perform scaling and shifting using gamma and beta parameters or not. + * Optionally it can also perform a fused ReLU, which in case of training would + * also require a workspace. + * + * @sa mkldnn_batch_normalization_desc_t + * @{ */ + +/** Initializes a batch normalization descriptor @p bnrm_desc for forward + * propagation using @p prop_kind (possible values are + * #mkldnn_forward_training and #mkldnn_forward_inference), memory descriptor + * @p data_desc, normalization parameter @p epsilon, and @p flags set using bit + * flags of type mkldnn_batch_normalization_desc_t. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - mean (#mkldnn_query_src_md, 1), + * if #mkldnn_use_global_stats bit-flags is set in @p flags + * - variance (#mkldnn_query_src_md, 2), + * if #mkldnn_use_global_stats bit-flags is set in @p flags + * - scale_and_shift (#mkldnn_query_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - mean (#mkldnn_query_dst_md, 1), + * if #mkldnn_use_global_stats bit-flags is not set in @p flags + * @p prop_kind = #mkldnn_forward_training + * - variance (#mkldnn_query_dst_md, 2), + * if #mkldnn_use_global_stats bit-flags is not set in @p flags + * and @p prop_kind = #mkldnn_forward_training + * - workspace (#mkldnn_query_workspace_md, 0), + * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags + * and @p prop_kind = #mkldnn_forward_training + * + * @note In-place operation is supported; that is, dst points to the same memory + * as src. + * + * @sa mkldnn_batch_normalization_desc_t + */ +mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init( + mkldnn_batch_normalization_desc_t *bnrm_desc, + mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, + float epsilon, unsigned flags); + +/** Initializes a batch normalization descriptor @p bnrm_desc for backward + * propagation with respect to data and scale-shift parameters using memory + * descriptors @p data_desc and @p diff_data_desc, normalization parameter + * @p epsilon, and @p flags set using bit flags of type + * mkldnn_batch_normalization_desc_t. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - mean (#mkldnn_query_src_md, 1) + * - variance (#mkldnn_query_src_md, 2) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - scale_and_shift (#mkldnn_query_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * - workspace (#mkldnn_query_workspace_md, 0), + * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + * - diff_scale_and_shift (#mkldnn_query_diff_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * and @p prop_kind = #mkldnn_backward + * + * @note in-place operation is supported, + * i.e. diff_src points to the same memory as diff_dst. + * + * @sa mkldnn_batch_normalization_desc_t + */ +mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init( + mkldnn_batch_normalization_desc_t *bnrm_desc, + mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, + float epsilon, unsigned flags); + +/** @} */ + +/** @addtogroup c_api_inner_product Inner product + * A primitive to compute an inner product. + * + * Inner product layer is also known as fully connected layer. + * With spatial dimension: + * + * \f[dst[n][oc] = \sum\limits_{ic, kh, kw} + * src[n][ic][kh][kw] \cdot weights[oc][ic][kh][kw] + * + bias[oc]\f] + * @{ */ + +/** Initializes an inner product descriptor @p ip_desc for forward propagation + * using @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference) and memory descriptors. In order to create an + * inner product without bias, @p bias_desc should be either @c NULL or a + * pointer to a descriptor with memory format kind equals + * #mkldnn_format_kind_undef. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init( + mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc); + +/** Initializes an inner product descriptor @p ip_desc for backward propagation + * with respect to data using memory descriptors. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init( + mkldnn_inner_product_desc_t *ip_desc, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc); + +/** Initializes an inner product descriptor @p ip_desc for backward propagation + * with respect to weights using memory descriptors. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init( + mkldnn_inner_product_desc_t *ip_desc, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc); + +/** @} */ + +/** @addtogroup c_api_rnn RNN + * A primitive to compute the common recurrent layer. + * @todo add additional description for the group + * @{ */ + +/** + * Initializes a recurrent cell descriptor @p rnn_cell_desc + * using @p rnn_cell_desc, @p kind (possible values are + * #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, and + * #mkldnn_gru_linear_before_reset), + * @p f (possible values are #mkldnn_eltwise_relu and + * #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping. + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init( + mkldnn_rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, + unsigned int flags, float alpha, float clipping); + +/** Returns the number of gates of a particular @p rnn_cell_desc. */ +int MKLDNN_API mkldnn_rnn_cell_get_gates_count( + const mkldnn_rnn_cell_desc_t *rnn_cell_desc); + +/** Returns the number of states of a particular @p rnn_cell_desc. */ +int MKLDNN_API mkldnn_rnn_cell_get_states_count( + const mkldnn_rnn_cell_desc_t *rnn_cell_desc); + +/** Sets quantization @p scale and @p shift for RNN data tensors. + * For performance reasons, low precision configuration of RNN primitive + * expects input activations to have unsigned int8 data type. Scale and shift + * used to quantize floating point data to unsigned integer must be passed to + * RNN primitive using attributes. + * Example usage: + * @code + * // rnn parameters + * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + * // activations quantization parameters + * float scale = ..., shift = ..; + * + * mkldnn_primitive_attr_t rnn_attr; + * // create default attributes + * mkldnn_primitive_attr_create(&rnn_attr); + * + * // set scale and shift for int8 quantization of activation + * mkldnn_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift); + * + * // create & configure rnn op_desc + * mkldnn_rnn_desc_t rnn_d; + * mkldnn_primitive_desc_t rnn_pd; + * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); + * @endcode + * @note + * Quantization scale and shift are common for src_layer, src_iter, + * dst_iter and dst_layer. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams( + mkldnn_primitive_attr_t attr, const float scale, const float shift); + +/** Sets quantization scales @p weights_scales for RNN weights tensors. + * Low precision configuration of RNN primitive expects input weights to have + * signed int8 data type. Scales used to quantize floating point data + * to signed integer must be passed to RNN primitive using attributes. + * The @p mask argument defines correspondence between output tensor dimensions + * and the @p weights_scales array. Set i-th bit of @p mask to 1 to use + * dedicated scaling factor for each slice of the output tensor over i-th + * dimension. Set @p mask to 0 to use common scaling factor for the whole output + * tensor. Example usage: + * @code + * // rnn parameters + * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + * // unique output scales per output channel + * float weights_scales[dic * n_gates] = { ... }; + * // mask that specifies last two dimensions of ldigo format + * int mask = 0x3; + * + * mkldnn_primitive_attr_t attr; + * // create default attributes + * mkldnn_primitive_attr_create(&attr); + * + * // set output channel-wise weights scales + * mkldnn_primitive_attr_set_rnn_weights_qparams(attr, dic * n_gates, mask, + * weights_scales); + * + * // create & configure rnn op_desc + * mkldnn_rnn_desc_t rnn_d; + * mkldnn_primitive_desc_t rnn_pd; + * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); + * @endcode + * @note + * The dimension order is always native and does not depend on the actual + * layout used. For example, 5 dimensional weights always have + * (l, d, i, g, o) logical dimension ordering. + * @note + * Quantization sales are common for weights_layer and weights_iteration + * @note + * There is no way to check that @p count corresponds to @p mask until an + * actual primitive descriptor is created, so it is user's responsibility + * to set proper values. The following formula must be held: + * + * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams ( + mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, + const float *weights_scales); + +/** Initializes a rnn descriptor @p rnn_desc for forward propagation + * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. + * @note If @p prop_kind equals #mkldnn_forward_training, you must query a + * workspace memory descriptor before creating the primitive. + * + * @p src_iter_desc, @p bias_desc, and @p dst_iter_desc are allowed to either be + * @c NULL or point to a zero memory descriptor, which would indicate that the + * RNN primitive should not use them. + * + * @note All memory descriptors except @p src_iter_desc are allowed to be + * initialized with #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src_layer (#mkldnn_query_src_md, 0) + * - src_iter (#mkldnn_query_src_md, 1), if used + * - weights_layer (#mkldnn_query_weights_md, 0) + * - weights_iter (#mkldnn_query_weights_md, 1) + * - bias (#mkldnn_query_weights_md, 2), if used + * + * Outputs: + * - dst_layer (#mkldnn_query_dst_md, 0) + * - dst_iter (#mkldnn_query_dst_md, 1), if used + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p prop_kind equals #mkldnn_forward_training + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init( + mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_rnn_cell_desc_t *rnn_cell_desc, + const mkldnn_rnn_direction_t direction, + const mkldnn_memory_desc_t *src_layer_desc, + const mkldnn_memory_desc_t *src_iter_desc, + const mkldnn_memory_desc_t *weights_layer_desc, + const mkldnn_memory_desc_t *weights_iter_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_layer_desc, + const mkldnn_memory_desc_t *dst_iter_desc); + +/** Initializes a rnn descriptor @p rnn_desc for backward propagation + * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. + * + * @note All memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * @p src_iter_desc (simultaneously with @p diff_src_iter_desc), + * @p bias_desc (simultaneously with @p diff_bias_desc), and + * @p dst_iter_desc (simultaneously with @p diff_src_iter_desc) are allowed to + * either be @c NULL or point to a zero memory descriptor, which would indicate + * that the RNN primitive should not use them. + * + * Inputs: + * - src_layer (#mkldnn_query_src_md, 0) + * - src_iter (#mkldnn_query_src_md, 1), if used + * - weights_layer (#mkldnn_query_weights_md, 0) + * - weights_iter (#mkldnn_query_weights_md, 1) + * - bias (#mkldnn_query_weights_md, 2), if used + * - dst_layer (#mkldnn_query_dst_md, 0) + * - dst_iter (#mkldnn_query_dst_md, 1), if used + * - diff_dst_layer (#mkldnn_query_diff_dst_md, 0) + * - diff_dst_iter (#mkldnn_query_diff_dst_md, 1), if used + * - workspace (#mkldnn_query_workspace_md, 0) + * + * Outputs: + * - diff_src_layer (#mkldnn_query_diff_src_md, 0) + * - diff_src_iter (#mkldnn_query_diff_src_md, 1), if used + * - diff_weights_layer (#mkldnn_query_diff_weights_md, 0) + * - diff_weights_iter (#mkldnn_query_diff_weights_md, 1) + * - diff_bias (#mkldnn_query_diff_weights_md, 2), if used + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init( + mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_rnn_cell_desc_t *rnn_cell_desc, + const mkldnn_rnn_direction_t direction, + const mkldnn_memory_desc_t *src_layer_desc, + const mkldnn_memory_desc_t *src_iter_desc, + const mkldnn_memory_desc_t *weights_layer_desc, + const mkldnn_memory_desc_t *weights_iter_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_layer_desc, + const mkldnn_memory_desc_t *dst_iter_desc, + const mkldnn_memory_desc_t *diff_src_layer_desc, + const mkldnn_memory_desc_t *diff_src_iter_desc, + const mkldnn_memory_desc_t *diff_weights_layer_desc, + const mkldnn_memory_desc_t *diff_weights_iter_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_layer, + const mkldnn_memory_desc_t *diff_dst_iter_desc); + +/** @} */ + +/** @} */ + +/** @addtogroup c_api_engine Engine operations + * @{ */ + +/** Returns the number of engines of a particular @p kind. */ +size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind); + +/** Creates an @p engine of particular @p kind and @p index. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, + mkldnn_engine_kind_t kind, size_t index); + +/** Returns the kind of an @p engine. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_get_kind(mkldnn_engine_t engine, + mkldnn_engine_kind_t *kind); + +/** Destroys an @p engine. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_stream Execution stream operations + * @{ */ + +/** Creates an execution @p stream for @p engine and with @p flags. */ +mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, + mkldnn_engine_t engine, unsigned flags); + +/** Destroys an execution @p stream. */ +mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream); + +/** @} */ + +/** @addtogroup c_api_service Service functions + * @{ */ + +/** Sets verbosity level (print information to stdout). + * Possible levels are: + * - 0 -- no verbose output (default) + * - 1 -- primitive information at execution + * - 2 -- primitive information at creation and execution + * + * @note + * Dumping information might affect performance. + * This setting overrides the MKLDNN_VERBOSE environment variable. */ +mkldnn_status_t MKLDNN_API mkldnn_set_verbose(int level); + +/** Enables or disables dumping of JIT-generated code. + * The enable parameter can be: + * - 0 -- disable + * - any other value -- enable + * + * @note + * This setting overrides the MKLDNN_JIT_DUMP environment variable. */ +mkldnn_status_t MKLDNN_API mkldnn_set_jit_dump(int enable); + +/** Gets library version information. + * Version information includes: + * - major -- major version number + * - minor -- minor version number + * - patch -- patch release number + * - hash -- git commit hash */ +const mkldnn_version_t MKLDNN_API *mkldnn_version(); + +/** @} */ + +/** @addtogroup c_api_blas BLAS functions + * A subset of Basic Linear ALgebra (BLAS) functions to perform + * matrix-matrix multiplication. + * @{ */ + +/** SGEMM performs a matrix-matrix multiplication operation defined as + * + * C := alpha*op( A )*op( B ) + beta*C + * + * where + * - op( X ) is one of op( X ) = X or op( X ) = X**T, + * - alpha and beta are scalars, + * - A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix + * and C an m by n matrix. + * + * The matrices are assumed to be stored in column-major order (the elements + * in a matrix columns are contiguous in memory). + * + * @note + * The API is different from the standard BLAS routine + * because it returns mkldnn_status_t for error handling. + * XERBLA is not supported: no error message will be printed + * in case of incorrect parameters. */ +mkldnn_status_t MKLDNN_API mkldnn_sgemm( + const char *transa, const char *transb, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, const float *A, const mkldnn_dim_t *lda, + const float *B, const mkldnn_dim_t *ldb, + const float *beta, float *C, const mkldnn_dim_t *ldc); + +/** gemm_s8u8s32 and gemm_s8s8s32 perform a matrix-matrix multiplication + * operation and add the result to a scalar-matrix product. For the final + * result, a vector is added to each row or column of the output matrix. + * The operation is defined as: + * + * C := alpha*(op(A) + A_offset) * (op(B) + B_offset) + beta*C + C_offset + * + * where + * - op( X ) = X or op( X ) = X**T, + * - A_offset is an m-by-k matrix with every element equal to the value oa, + * - B_offset is an k-by-n matrix with every element equal to the value ob, + * - C_offset is an m-by-n matrix defined by the oc array, size len: + * - if offsetc = F: len must be at least 1 + * - if offsetc = C: len must be at least max(1, m) + * - if offsetc = R: len must be at least max(1, n) + * - alpha and beta are scalars, and A, B and C are matrices, with op( A ) + * an m-by-k matrix, op( B ) a k-by-n matrix and C an m-by-n matrix. + * + * The matrices are assumed to be stored in column-major order (the elements + * in a matrix columns are contiguous in memory). + * + * @note + * The API is different compared with the standard BLAS routine + * because it returns mkldnn_status_t for error handling. + * XERBLA is not supported: no error message will be printed + * in case of incorrect parameters. */ +mkldnn_status_t MKLDNN_API mkldnn_gemm_s8u8s32( + const char *transa, const char *transb, const char *offsetc, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, + const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, + const uint8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, + const float *beta, + int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); + +mkldnn_status_t MKLDNN_API mkldnn_gemm_s8s8s32( + const char *transa, const char *transb, const char *offsetc, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, + const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, + const int8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, + const float *beta, + int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); +/** @} */ + +/** @} */ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp new file mode 100644 index 0000000000..581400a013 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp @@ -0,0 +1,2615 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_HPP +#define MKLDNN_HPP + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include +#include +#include +#include +#include +#include + +#include "mkldnn.h" +#endif + +namespace mkldnn { + +/// @addtogroup cpp_api C++ API +/// @{ + +/// @addtogroup cpp_api_utils Utils +/// @{ + +/// A class that provides the destructor for an Intel(R) MKL-DNN C handle +template class handle_traits {}; + +/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base +/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and +/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class +/// can be passed by value. This class enables wrapping: +/// - Newly constructed handles. +/// @n In this case, the constructed handle uses reference counting provided +/// by @p std::shared_ptr with a proper deleter function specified through +/// the @p handle_traits class. +/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for +/// example, through mkldnn_primitive_get_primitive_desc()). +/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a +/// deleter because it is assumed that the handle wrapper for the original +/// object deletes the handle (this model is similar to @p std::weak_ptr). +template > class handle { +private: + std::shared_ptr::type> _data; + handle(const handle &&) = delete; + handle &operator=(const handle &&other) = delete; +protected: + bool operator==(const T other) const { return other == _data.get(); } + bool operator!=(const T other) const { return !(*this == other); } +public: + /// Constructs a C handle wrapper. + /// @param t The C handle to wrap. + /// @param weak A flag to specify whether to construct a weak wrapper. + handle(T t = 0, bool weak = false): _data(0) { + reset(t, weak); + } + + handle(const handle &other): _data(other._data) {} + handle &operator=(const handle &other) { + _data = other._data; + return *this; + } + /// Resets the value of a C handle. + /// @param t The new value of the C handle. + /// @param weak A flag to specify whether the wrapper should be weak. + void reset(T t, bool weak = false) { + auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); }; + _data.reset(t, weak ? dummy_destructor : traits::destructor); + } + + /// Returns the value of the underlying C handle. + T get() const { return _data.get(); } + + bool operator==(const handle &other) const { return other._data.get() == _data.get(); } + bool operator!=(const handle &other) const { return !(*this == other); } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_memory_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_desc_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_destroy; +}; + +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy; +}; +#endif + +struct memory; +struct primitive_desc; + +/// Base class for all computational primitives. +class primitive: public handle { + friend struct error; + friend struct stream; + using handle::handle; +public: + /// A proxy to C primitive kind enum + enum class kind { + undefined_primitive = mkldnn_undefined_primitive, + reorder = mkldnn_reorder, + concat = mkldnn_concat, + sum = mkldnn_sum, + convolution = mkldnn_convolution, + deconvolution = mkldnn_deconvolution, + shuffle = mkldnn_shuffle, + eltwise = mkldnn_eltwise, + softmax = mkldnn_softmax, + pooling = mkldnn_pooling, + lrn = mkldnn_lrn, + batch_normalization = mkldnn_batch_normalization, + inner_product = mkldnn_inner_product, + rnn = mkldnn_rnn, + }; + + primitive(const_mkldnn_primitive_desc_t c_pd); + primitive(const primitive_desc &pd); + + /// Returns the descriptor of the underlying C API primitive. + inline const_mkldnn_primitive_desc_t get_primitive_desc() const; + // TODO: use the C++ API wrapper structure. + + void execute(struct stream &astream, + const std::unordered_map &args) const; +}; + +inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { + return static_cast(akind); +} +/// Intel(R) MKL-DNN exception class. +/// +/// This class captures the status returned by the failed C API function, error +/// message, and, optionally, handle of the primitive that caused the error. +struct error: public std::exception { + mkldnn_status_t status; + const char *message; + + /// Constructs an error instance. + /// + /// @param astatus The error status returned by the C API. + /// @param amessage The error message. + error(mkldnn_status_t astatus, const char *amessage) + : status(astatus), message(amessage) {} + + /// A convenience function for wrapping calls to the C API. Checks the + /// return status and throws an #error in case of failure. + /// + /// @param status The error status returned by the C API. + /// @param message The error message. + static void wrap_c_api(mkldnn_status_t status, const char *message) { + if (status != mkldnn_success) + throw error(status, message); + } +}; + +const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { + const_mkldnn_primitive_desc_t pd; + error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), + "could not get primitive descriptor by primitive"); + return pd; +} +/// @} + +/// @addtogroup cpp_api_enums Common data types and enumerations +/// A proxy to @ref c_api_types in @ref c_api. +/// +/// @{ + +enum scratchpad_mode { + scratchpad_mode_library = mkldnn_scratchpad_mode_library, + scratchpad_mode_user = mkldnn_scratchpad_mode_user, +}; + +inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { + return static_cast(mode); +} + +enum padding_kind { + zero = mkldnn_padding_zero +}; + +inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { + return static_cast(kind); +} + +enum prop_kind { + forward_training = mkldnn_forward_training, + forward_scoring = mkldnn_forward_scoring, + forward_inference = mkldnn_forward_inference, + forward = mkldnn_forward, + backward = mkldnn_backward, + backward_data = mkldnn_backward_data, + backward_weights = mkldnn_backward_weights, + backward_bias = mkldnn_backward_bias +}; + +inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { + return static_cast(kind); +} + +enum algorithm { + algorithm_undef = mkldnn_alg_kind_undef, + convolution_auto = mkldnn_convolution_auto, + convolution_direct = mkldnn_convolution_direct, + convolution_winograd = mkldnn_convolution_winograd, + deconvolution_direct = mkldnn_deconvolution_direct, + deconvolution_winograd = mkldnn_deconvolution_winograd, + eltwise_relu = mkldnn_eltwise_relu, + eltwise_tanh = mkldnn_eltwise_tanh, + eltwise_elu = mkldnn_eltwise_elu, + eltwise_square = mkldnn_eltwise_square, + eltwise_abs = mkldnn_eltwise_abs, + eltwise_sqrt = mkldnn_eltwise_sqrt, + eltwise_linear = mkldnn_eltwise_linear, + eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, + eltwise_soft_relu = mkldnn_eltwise_soft_relu, + eltwise_logistic = mkldnn_eltwise_logistic, + lrn_across_channels = mkldnn_lrn_across_channels, + lrn_within_channel = mkldnn_lrn_within_channel, + pooling_max = mkldnn_pooling_max, + pooling_avg = mkldnn_pooling_avg, + pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, + pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, + vanilla_rnn = mkldnn_vanilla_rnn, + vanilla_lstm = mkldnn_vanilla_lstm, + vanilla_gru = mkldnn_vanilla_gru, + gru_linear_before_reset = mkldnn_gru_linear_before_reset +}; + +inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { + return static_cast(aalgorithm); +} + +enum batch_normalization_flag { + use_global_stats = mkldnn_use_global_stats, + use_scale_shift = mkldnn_use_scaleshift, + fuse_bn_relu = mkldnn_fuse_bn_relu +}; + +inline mkldnn_batch_normalization_flag_t convert_to_c( + batch_normalization_flag aflag) { + return static_cast(aflag); +} + +enum rnn_direction { + unidirectional_left2right = mkldnn_unidirectional_left2right, + unidirectional_right2left = mkldnn_unidirectional_right2left, + unidirectional = mkldnn_unidirectional, + bidirectional_concat = mkldnn_bidirectional_concat, + bidirectional_sum = mkldnn_bidirectional_sum, +}; + +inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { + return static_cast(adir); +} + +enum query { + undef = mkldnn_query_undef, + + query_engine = mkldnn_query_engine, + primitive_kind = mkldnn_query_primitive_kind, + + num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32, + num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, + + time_estimate_f64 = mkldnn_query_time_estimate_f64, + memory_consumption_s64 = mkldnn_query_memory_consumption_s64, + + query_scratchpad_engine = mkldnn_query_scratchpad_engine, + + impl_info_str = mkldnn_query_impl_info_str, + + op_d = mkldnn_query_op_d, + convolution_d = mkldnn_query_convolution_d, + deconvolution_d = mkldnn_query_deconvolution_d, + shuffle_d = mkldnn_query_shuffle_d, + eltwise_d = mkldnn_query_eltwise_d, + softmax_d = mkldnn_query_softmax_d, + pooling_d = mkldnn_query_pooling_d, + lrn_d = mkldnn_query_lrn_d, + batch_normalization_d = mkldnn_query_batch_normalization_d, + inner_product_d = mkldnn_query_inner_product_d, + rnn_d = mkldnn_query_rnn_d, + + src_md = mkldnn_query_src_md, + diff_src_md = mkldnn_query_diff_src_md, + weights_md = mkldnn_query_weights_md, + diff_weights_md = mkldnn_query_diff_weights_md, + dst_md = mkldnn_query_dst_md, + diff_dst_md = mkldnn_query_diff_dst_md, + workspace_md = mkldnn_query_workspace_md, + scratchpad_md = mkldnn_query_scratchpad_md, +}; + +inline mkldnn_query_t convert_to_c(query aquery) { + return static_cast(aquery); +} + +/// @} + +/// @addtogroup cpp_api_attr Attributes +/// An extension for controlling primitive behavior. +/// +/// @sa @ref c_api_attributes in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_post_ops_destroy; +}; +#endif + +struct post_ops: public handle { + post_ops() { + mkldnn_post_ops_t result; + error::wrap_c_api(mkldnn_post_ops_create(&result), + "could not create post operation sequence"); + reset(result); + } + + int len() const { return mkldnn_post_ops_len(get()); } + + primitive::kind kind(int index) const { + error::wrap_c_api( + index < len() ? mkldnn_success : mkldnn_invalid_arguments, + "post_ops index is out of range"); + return static_cast(mkldnn_post_ops_get_kind(get(), + index)); + } + + void append_sum(float scale = 1.) { + error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), + "could not append sum"); + } + + void get_params_sum(int index, float &scale) const { + error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), + "could not get sum params"); + } + + void append_eltwise(float scale, algorithm alg, float alpha, + float beta) { + error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale, + convert_to_c(alg), alpha, beta), + "could not append eltwise"); + } + + void get_params_eltwise(int index, float &scale, algorithm &alg, + float &alpha, float &beta) const { + mkldnn_alg_kind_t c_alg; + error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index, + &scale, &c_alg, &alpha, &beta), + "could not get eltwise params"); + alg = static_cast(c_alg); + } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_primitive_attr_destroy; +}; +#endif + +struct primitive_attr: public handle { + primitive_attr() { + mkldnn_primitive_attr_t result; + error::wrap_c_api(mkldnn_primitive_attr_create(&result), + "could not create a primitive attr"); + reset(result); + } + + scratchpad_mode get_scratchpad_mode() const { + mkldnn_scratchpad_mode_t result; + error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode( + get(), &result), "could not get scratchpad mode"); + return scratchpad_mode(result); + } + + void set_scratchpad_mode(scratchpad_mode mode) { + error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode( + get(), mkldnn::convert_to_c(mode)), + "could not set scratchpad mode"); + } + + void get_output_scales(int &mask, std::vector &scales) const + { + mkldnn_dim_t count; + int c_mask; + const float *c_scales; + error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(), + &count, &c_mask, &c_scales), + "could not get int output scales"); + scales.resize(count); + + mask = c_mask; + for (mkldnn_dim_t c = 0; c < count; ++c) + scales[c] = c_scales[c]; + } + + void set_output_scales(int mask, const std::vector &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(), + (mkldnn_dim_t)scales.size(), mask, &scales[0]), + "could not set int output scales"); + } + + const post_ops get_post_ops() const { + post_ops result; + const_mkldnn_post_ops_t c_result; + error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), + "could not get post operation sequence"); + result.reset(const_cast(c_result), true); + return result; + } + + void set_post_ops(post_ops ops) { + error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), + "could not set post operation sequence"); + } + + void set_rnn_data_qparams(const float scale, const float shift) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(), + scale, shift), "could not set rnn data int scale/shift"); + } + + void set_rnn_weights_qparams(int mask, const std::vector &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(), + (int)scales.size(), mask, &scales[0]), + "could not set rnn weights int scales"); + } +}; + +/// @} + +/// @addtogroup cpp_api_engine Engine +/// Engine operations. +/// +/// @sa @ref c_api_engine in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_engine_destroy; +}; +#endif + +/// An execution engine. +struct engine: public handle { + friend class primitive; + // gcc bug??? using handle::handle; + + /// Kinds of engines. + enum kind { + /// An unspecified engine + any = mkldnn_any_engine, + /// CPU engine + cpu = mkldnn_cpu, + }; + + /// Returns the number of engines of a certain kind. + /// + /// @param akind The kind of engines to count. + + static size_t get_count(kind akind) { + return mkldnn_engine_get_count(convert_to_c(akind)); + } + + /// Constructs an engine. + /// + /// @param akind The kind of engine to construct. + /// @param index The index of the engine. Must be less than the value + /// returned by #get_count() for this particular kind of engine. + + engine(kind akind, size_t index) { + mkldnn_engine_t aengine; + error::wrap_c_api( + mkldnn_engine_create(&aengine, + convert_to_c(akind), index), + "could not create an engine"); + reset(aengine); + } + + explicit engine(const mkldnn_engine_t& aengine) + : handle(aengine, true) {} + + engine(const handle &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + reset(engine_q, true); + } + + template + static engine query(const primitive_desc &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + + return engine(engine_q); + } + +private: + static mkldnn_engine_kind_t convert_to_c(kind akind) { + return static_cast(akind); + } +}; + +/// @} + +/// @addtogroup cpp_api_stream Stream +/// Execution stream operations +/// +/// @sa @ref c_api_stream in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits { + static constexpr auto destructor = &mkldnn_stream_destroy; +}; +#endif + +struct stream: public handle { + using handle::handle; + + enum: unsigned { + default_flags = mkldnn_stream_default_flags, + }; + + /// Constructs a stream. + stream(const engine &aengine, + unsigned flags = static_cast(default_flags)) { + mkldnn_stream_t astream; + error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags), + "could not create a stream"); + reset(astream); + } +}; + +/// @} + +/// @addtogroup cpp_api_memory_related Memory and memory related operations +/// @{ + +/// @addtogroup cpp_api_memory Memory +/// A primitive to describe and store data. +/// +/// For more information, refer to @ref c_api_memory in @ref c_api. +/// @{ + +/// Memory that describes the data. +struct memory: public handle { + public: + typedef mkldnn_dim_t dim; + typedef std::vector dims; + + template static void validate_dims(const std::vector &v) { + if (v.size() > MKLDNN_MAX_NDIMS) + throw error(mkldnn_invalid_arguments, "invalid dimensions"); + } + + /// Data type specification. See #mkldnn_data_type_t for a detailed + /// description. + enum data_type { + data_undef = mkldnn_data_type_undef, + f32 = mkldnn_f32, + s32 = mkldnn_s32, + s8 = mkldnn_s8, + u8 = mkldnn_u8, + }; + + /// Memory format tag specification. See #mkldnn_format_tag_t + /// for a detailed description. + enum format_tag { + format_tag_undef = mkldnn_format_tag_undef, + any = mkldnn_format_tag_any, + a = mkldnn_a, + ab = mkldnn_ab, + abc = mkldnn_abc, + abcd = mkldnn_abcd, + abcde = mkldnn_abcde, + abcdef = mkldnn_abcdef, + abdec = mkldnn_abdec, + acb = mkldnn_acb, + acbde = mkldnn_acbde, + acdb = mkldnn_acdb, + acdeb = mkldnn_acdeb, + ba = mkldnn_ba, + bac = mkldnn_bac, + bacd = mkldnn_bacd, + bcda = mkldnn_bcda, + cba = mkldnn_cba, + cdba = mkldnn_cdba, + cdeba = mkldnn_cdeba, + decab = mkldnn_decab, + Abc16a = mkldnn_Abc16a, + ABc16a16b = mkldnn_ABc16a16b, + aBc16b = mkldnn_aBc16b, + ABc16b16a = mkldnn_ABc16b16a, + Abc4a = mkldnn_Abc4a, + aBc4b = mkldnn_aBc4b, + ABc4b16a4b = mkldnn_ABc4b16a4b, + ABc4b4a = mkldnn_ABc4b4a, + ABc8a16b2a = mkldnn_ABc8a16b2a, + ABc8a8b = mkldnn_ABc8a8b, + aBc8b = mkldnn_aBc8b, + ABc8b16a2b = mkldnn_ABc8b16a2b, + ABc8b8a = mkldnn_ABc8b8a, + Abcd16a = mkldnn_Abcd16a, + ABcd16a16b = mkldnn_ABcd16a16b, + aBcd16b = mkldnn_aBcd16b, + ABcd16b16a = mkldnn_ABcd16b16a, + aBCd16b16c = mkldnn_aBCd16b16c, + aBCd16c16b = mkldnn_aBCd16c16b, + Abcd4a = mkldnn_Abcd4a, + aBcd4b = mkldnn_aBcd4b, + ABcd4b16a4b = mkldnn_ABcd4b16a4b, + ABcd4b4a = mkldnn_ABcd4b4a, + aBCd4c16b4c = mkldnn_aBCd4c16b4c, + aBCd4c4b = mkldnn_aBCd4c4b, + ABcd8a16b2a = mkldnn_ABcd8a16b2a, + ABcd8a8b = mkldnn_ABcd8a8b, + aBcd8b = mkldnn_aBcd8b, + ABcd8b16a2b = mkldnn_ABcd8b16a2b, + aBCd8b16c2b = mkldnn_aBCd8b16c2b, + ABcd8b8a = mkldnn_ABcd8b8a, + aBCd8b8c = mkldnn_aBCd8b8c, + aBCd8c16b2c = mkldnn_aBCd8c16b2c, + aBCd8c8b = mkldnn_aBCd8c8b, + Abcde16a = mkldnn_Abcde16a, + ABcde16a16b = mkldnn_ABcde16a16b, + aBcde16b = mkldnn_aBcde16b, + ABcde16b16a = mkldnn_ABcde16b16a, + aBCde16b16c = mkldnn_aBCde16b16c, + aBCde16c16b = mkldnn_aBCde16c16b, + aBCde2c8b4c = mkldnn_aBCde2c8b4c, + Abcde4a = mkldnn_Abcde4a, + aBcde4b = mkldnn_aBcde4b, + ABcde4b4a = mkldnn_ABcde4b4a, + aBCde4b4c = mkldnn_aBCde4b4c, + aBCde4c16b4c = mkldnn_aBCde4c16b4c, + aBCde4c4b = mkldnn_aBCde4c4b, + Abcde8a = mkldnn_Abcde8a, + ABcde8a8b = mkldnn_ABcde8a8b, + aBcde8b = mkldnn_aBcde8b, + ABcde8b16a2b = mkldnn_ABcde8b16a2b, + aBCde8b16c2b = mkldnn_aBCde8b16c2b, + ABcde8b8a = mkldnn_ABcde8b8a, + aBCde8b8c = mkldnn_aBCde8b8c, + aBCde8c16b2c = mkldnn_aBCde8c16b2c, + aBCde8c8b = mkldnn_aBCde8c8b, + aBcdef16b = mkldnn_aBcdef16b, + aBCdef16b16c = mkldnn_aBCdef16b16c, + aBCdef16c16b = mkldnn_aBCdef16c16b, + aBcdef4b = mkldnn_aBcdef4b, + aBCdef4c4b = mkldnn_aBCdef4c4b, + aBCdef8b8c = mkldnn_aBCdef8b8c, + aBCdef8c16b2c = mkldnn_aBCdef8c16b2c, + aBCdef8c8b = mkldnn_aBCdef8c8b, + aBdc16b = mkldnn_aBdc16b, + aBdc4b = mkldnn_aBdc4b, + aBdc8b = mkldnn_aBdc8b, + aBdec16b = mkldnn_aBdec16b, + aBdec4b = mkldnn_aBdec4b, + aBdec8b = mkldnn_aBdec8b, + aBdefc16b = mkldnn_aBdefc16b, + aBdefc4b = mkldnn_aBdefc4b, + aBdefc8b = mkldnn_aBdefc8b, + Acb16a = mkldnn_Acb16a, + Acb4a = mkldnn_Acb4a, + Acb8a = mkldnn_Acb8a, + aCBd16b16c = mkldnn_aCBd16b16c, + aCBde16b16c = mkldnn_aCBde16b16c, + Acdb16a = mkldnn_Acdb16a, + Acdb4a = mkldnn_Acdb4a, + Acdb8a = mkldnn_Acdb8a, + Acdeb16a = mkldnn_Acdeb16a, + Acdeb4a = mkldnn_Acdeb4a, + Acdeb8a = mkldnn_Acdeb8a, + BAc16a16b = mkldnn_BAc16a16b, + BAcd16a16b = mkldnn_BAcd16a16b, + format_tag_last = mkldnn_format_tag_last, + + x = mkldnn_x, + nc = mkldnn_nc, + cn = mkldnn_cn, + ncw = mkldnn_ncw, + nwc = mkldnn_nwc, + nchw = mkldnn_nchw, + nhwc = mkldnn_nhwc, + chwn = mkldnn_chwn, + ncdhw = mkldnn_ncdhw, + ndhwc = mkldnn_ndhwc, + oi = mkldnn_oi, + io = mkldnn_io, + oiw = mkldnn_oiw, + wio = mkldnn_wio, + oihw = mkldnn_oihw, + hwio = mkldnn_hwio, + ihwo = mkldnn_ihwo, + iohw = mkldnn_iohw, + oidhw = mkldnn_oidhw, + dhwio = mkldnn_dhwio, + goiw = mkldnn_goiw, + goihw = mkldnn_goihw, + hwigo = mkldnn_hwigo, + giohw = mkldnn_giohw, + goidhw = mkldnn_goidhw, + tnc = mkldnn_tnc, + ntc = mkldnn_ntc, + ldsnc = mkldnn_ldsnc, + ldigo = mkldnn_ldigo, + ldgoi = mkldnn_ldgoi, + ldgo = mkldnn_ldgo, + nCdhw16c = mkldnn_nCdhw16c, + nCdhw4c = mkldnn_nCdhw4c, + nCdhw8c = mkldnn_nCdhw8c, + nChw16c = mkldnn_nChw16c, + nChw4c = mkldnn_nChw4c, + nChw8c = mkldnn_nChw8c, + nCw16c = mkldnn_nCw16c, + nCw4c = mkldnn_nCw4c, + nCw8c = mkldnn_nCw8c, + IOw16o16i = mkldnn_IOw16o16i, + OIw16i16o = mkldnn_OIw16i16o, + OIw16o16i = mkldnn_OIw16o16i, + Oiw16o = mkldnn_Oiw16o, + OIw4i16o4i = mkldnn_OIw4i16o4i, + OIw4i4o = mkldnn_OIw4i4o, + Oiw4o = mkldnn_Oiw4o, + OIw8i16o2i = mkldnn_OIw8i16o2i, + OIw8i8o = mkldnn_OIw8i8o, + OIw8o16i2o = mkldnn_OIw8o16i2o, + OIw8o8i = mkldnn_OIw8o8i, + Owi16o = mkldnn_Owi16o, + Owi4o = mkldnn_Owi4o, + Owi8o = mkldnn_Owi8o, + IOhw16o16i = mkldnn_IOhw16o16i, + Ohwi16o = mkldnn_Ohwi16o, + Ohwi4o = mkldnn_Ohwi4o, + Ohwi8o = mkldnn_Ohwi8o, + OIhw16i16o = mkldnn_OIhw16i16o, + OIhw16o16i = mkldnn_OIhw16o16i, + Oihw16o = mkldnn_Oihw16o, + OIhw4i16o4i = mkldnn_OIhw4i16o4i, + OIhw4i4o = mkldnn_OIhw4i4o, + Oihw4o = mkldnn_Oihw4o, + OIhw8i16o2i = mkldnn_OIhw8i16o2i, + OIhw8i8o = mkldnn_OIhw8i8o, + OIhw8o16i2o = mkldnn_OIhw8o16i2o, + OIhw8o8i = mkldnn_OIhw8o8i, + Odhwi16o = mkldnn_Odhwi16o, + Odhwi4o = mkldnn_Odhwi4o, + Odhwi8o = mkldnn_Odhwi8o, + OIdhw16i16o = mkldnn_OIdhw16i16o, + OIdhw16o16i = mkldnn_OIdhw16o16i, + Oidhw16o = mkldnn_Oidhw16o, + OIdhw4i4o = mkldnn_OIdhw4i4o, + Oidhw4o = mkldnn_Oidhw4o, + OIdhw8i16o2i = mkldnn_OIdhw8i16o2i, + OIdhw8i8o = mkldnn_OIdhw8i8o, + OIdhw8o8i = mkldnn_OIdhw8o8i, + gIOw16o16i = mkldnn_gIOw16o16i, + gOIw16i16o = mkldnn_gOIw16i16o, + gOIw16o16i = mkldnn_gOIw16o16i, + gOiw16o = mkldnn_gOiw16o, + gOIw4i16o4i = mkldnn_gOIw4i16o4i, + gOIw4i4o = mkldnn_gOIw4i4o, + gOiw4o = mkldnn_gOiw4o, + gOIw8i16o2i = mkldnn_gOIw8i16o2i, + gOIw8i8o = mkldnn_gOIw8i8o, + gOIw8o16i2o = mkldnn_gOIw8o16i2o, + gOIw8o8i = mkldnn_gOIw8o8i, + gOwi16o = mkldnn_gOwi16o, + gOwi4o = mkldnn_gOwi4o, + gOwi8o = mkldnn_gOwi8o, + gIOhw16o16i = mkldnn_gIOhw16o16i, + gOhwi16o = mkldnn_gOhwi16o, + gOhwi4o = mkldnn_gOhwi4o, + gOhwi8o = mkldnn_gOhwi8o, + Goihw16g = mkldnn_Goihw16g, + gOIhw16i16o = mkldnn_gOIhw16i16o, + gOIhw16o16i = mkldnn_gOIhw16o16i, + gOihw16o = mkldnn_gOihw16o, + gOIhw2i8o4i = mkldnn_gOIhw2i8o4i, + gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, + gOIhw4i4o = mkldnn_gOIhw4i4o, + gOIhw4o4i = mkldnn_gOIhw4o4i, + gOihw4o = mkldnn_gOihw4o, + Goihw8g = mkldnn_Goihw8g, + gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, + gOIhw8i8o = mkldnn_gOIhw8i8o, + gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, + gOIhw8o8i = mkldnn_gOIhw8o8i, + gOdhwi16o = mkldnn_gOdhwi16o, + gOdhwi4o = mkldnn_gOdhwi4o, + gOdhwi8o = mkldnn_gOdhwi8o, + gOIdhw16i16o = mkldnn_gOIdhw16i16o, + gOIdhw16o16i = mkldnn_gOIdhw16o16i, + gOidhw16o = mkldnn_gOidhw16o, + gOIdhw4i4o = mkldnn_gOIdhw4i4o, + gOidhw4o = mkldnn_gOidhw4o, + gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i, + gOIdhw8i8o = mkldnn_gOIdhw8i8o, + gOIdhw8o8i = mkldnn_gOIdhw8o8i, + }; + + /// A memory descriptor. + struct desc { + friend struct memory; + /// The underlying C API data structure. + mkldnn_memory_desc_t data; + + /// Constructs a zero memory descriptor + desc(): data() {} + + /// Constructs a memory descriptor. + /// + /// @param adims Data dimensions + /// @param adata_type Data precision/type. + /// @param aformat Data layout format tag. + desc(const dims &adims, data_type adata_type, + format_tag aformat) { + validate_dims(adims); + error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(), + adims.size() == 0 ? nullptr : &adims[0], + convert_to_c(adata_type), convert_to_c(aformat)), + "could not initialize a memory descriptor"); + } + + /// Constructs a memory descriptor from a C API data structure. + /// + /// @param adata A C API #mkldnn_memory_desc_t structure. + desc(const mkldnn_memory_desc_t &adata): data(adata) {} + + /// Constructs a sub-memory descriptor + // + /// @param adims Sizes of a sub-memory + /// @param offsets Offsets of a sub-memory + desc submemory_desc(const dims &adims, const dims &offsets) { + mkldnn_memory_desc_t sub_md; + error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md, + &data, &adims[0], &offsets[0]), + "could not initialize a sub-memory"); + return desc(sub_md); + } + + /// Returns the number of bytes required to allocate the memory described + /// including the padding area. + size_t get_size() const { return mkldnn_memory_desc_get_size(&data); } + + bool operator==(const desc &other) const { + return mkldnn_memory_desc_equal(&data, &other.data) != 0; + } + + bool operator!=(const desc &other) const { return !operator==(other); } + }; + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + /// @param ahandle Native handle. + memory(const desc &md, const engine &aengine, void *ahandle) { + mkldnn_memory_t result; + error::wrap_c_api(mkldnn_memory_create(&result, &md.data, + aengine.get(), ahandle), "could not create a memory"); + reset(result); + } + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + memory(const desc &md, const engine &aengine) + : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {} + + /// Returns the descriptor of the memory. + desc get_desc() const { + const mkldnn_memory_desc_t *cdesc; + error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc), + "could not get memory descriptor from a memory"); + return desc(*cdesc); + } + + /// Returns the engine of the memory. + engine get_engine() const { + mkldnn_engine_t engine_q; + error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q), + "could not get engine from a memory"); + return engine(engine_q); + } + + /// Returns a handle of the data contained in the memory. + /// + /// On the CPU engine, this is a pointer to the allocated memory. + void *get_data_handle() const { + void *handle; + error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle), + "could not get native handle"); + return handle; + } + + void set_data_handle(void *handle) const { + error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle), + "could not set native handle"); + } + + // Must go away or be private: + static mkldnn_data_type_t convert_to_c(data_type adata_type) { + return static_cast(adata_type); + } + static mkldnn_format_tag_t convert_to_c(format_tag aformat) { + return static_cast(aformat); + } +}; + +inline bool operator==(mkldnn_data_type_t a, memory::data_type b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) { + return !(a == b); +} +inline bool operator==(memory::data_type a, mkldnn_data_type_t b) { + return b == a; +} +inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) { + return !(a == b); +} + +inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) { + return !(a == b); +} +inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) { + return b == a; +} +inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) { + return !(a == b); +} + +/// @} + +/// @addtogroup cpp_api_reorder Reorder +/// A primitive to copy data between memory formats. +/// +/// @sa @ref c_api_reorder in @ref c_api +/// @{ + +struct reorder : public primitive { + struct primitive_desc : public handle { + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine scratchpad_engine() { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q), + "could not get scratchpad engine from reorder primitive_desc"); + + return engine(engine_q); + } + + engine get_engine() { return engine::query(*this); } + }; + + reorder(const primitive_desc &pd): primitive(pd.get()) {} + + reorder(const memory &src, const memory &dst): + primitive(primitive_desc(src, dst).get()) {} + + void execute(stream astream, memory &src, memory &dst) { + primitive::execute(astream, + {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}}); + } +}; + +/// @} + +/// @addtogroup cpp_api_concat Concat +/// A primitive to concatenate data by arbitrary dimension. +/// +/// @sa @ref c_api_concat in @ref c_api +/// @{ + +struct concat : public primitive { + struct primitive_desc : public handle { + std::vector cpp_to_c( + const std::vector &srcs) { + std::vector c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, int concat_dimension, + const std::vector &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + primitive_desc(int concat_dimension, + const std::vector &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, nullptr, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + concat(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @addtogroup cpp_api_sum Sum +/// A primitive to sum data. +/// +/// @sa @ref c_api_sum in @ref c_api +/// @{ + +struct sum : public primitive { + struct primitive_desc : public handle { + std::vector cpp_to_c( + const std::vector &srcs) { + std::vector c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, + const std::vector &scales, + const std::vector &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + &scales[0], &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + primitive_desc(const std::vector &scales, + const std::vector &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result, + nullptr, (int)c_api_srcs.size(), &scales[0], + &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + sum(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @} + +/// @addtogroup cpp_api_primitives Primitives +/// @{ + +/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors +/// @{ + +/// A base class for all primitive descriptors. +struct primitive_desc : public handle { + primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, + const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) { + mkldnn_primitive_desc_iterator_t iterator = nullptr; + mkldnn_status_t status = mkldnn_primitive_desc_iterator_create( + &iterator, desc, attr ? attr->get() : nullptr, e.get(), + hint_fwd_pd); + error::wrap_c_api(status, + "could not create a primitive descriptor iterator"); + pd_iterator.reset(iterator); + fetch_impl(); + } + + engine get_engine() { return engine::query(*this); } + + primitive_attr get_primitive_attr() const { + const_mkldnn_primitive_attr_t const_cattr; + error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr), + "could not get attributes"); + mkldnn_primitive_attr_t cattr; + error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr), + "could not clone attributes"); + + primitive_attr attr; + attr.reset(cattr); + return attr; + } + + /// Returns implementation name + const char *impl_info_str() const { + const char *res; + error::wrap_c_api(mkldnn_primitive_desc_query(get(), + mkldnn_query_impl_info_str, 0, &res), + "could not query implementation info string"); + return res; + } + + /// Queries the memory::dim value (same as int64_t) + memory::dim query_s64(query q) const { + memory::dim res; + mkldnn_status_t status = mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(q), 0, &res); + return status == mkldnn_success ? res : 0; + } + + /// Advances the next implementation for the given op descriptor. + /// + /// Returns: + /// - @c true on success + /// - @c false if the last implementation reached, and + /// the primitive descriptor itself is kept unchanged + bool next_impl() { + mkldnn_status_t status = mkldnn_primitive_desc_iterator_next( + pd_iterator.get()); + if (status == mkldnn_iterator_ends) return false; + error::wrap_c_api(status, "primitive descriptor iterator next failed"); + + fetch_impl(); + return true; + } + + /// Queries and returns requested memory descriptor. + memory::desc query_md(query what, int idx = 0) const { + std::vector valid_q{src_md, diff_src_md, weights_md, + diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), + [=](query q) { return what == q; })) + throw error(mkldnn_invalid_arguments, "invalid memory query"); + + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(what), idx); + if (cdesc == nullptr) return memory::desc(); + + return memory::desc(*cdesc); + } + + // register specialized queries, e.g. src_desc() +# define REG_QUERY_MD(name, what, idx) \ + memory::desc name ## _desc() const { return query_md(what ## _md, idx); } + + private: + handle pd_iterator; + void fetch_impl() { + mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch( + pd_iterator.get()); + error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error, + "could not fetch a primitive descriptor from the iterator"); + reset(pd); + } +}; + +/// @} + +/// @addtogroup cpp_api_convolution Convolution +/// A primitive to compute convolution using different algorithms. +/// +/// @sa @ref c_api_convolution in @ref c_api +/// @{ + +struct convolution_forward: public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_data : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_weights : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} +// +/// @addtogroup cpp_api_deconvolution Deconvolution +/// A primitive to compute deconvolution using different algorithms. +/// +/// @sa @ref c_api_deconvolution in @ref c_api +/// @{ + +struct deconvolution_forward: public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_data : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_weights : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_lrn LRN +/// A primitive to perform local response normalization (LRN) across or within +/// channels. +/// +/// @sa @ref c_api_lrn in @ref c_api +/// @{ + +struct lrn_forward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, local_size, alpha, beta, k), + "could not create a lrn forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct lrn_backward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(algorithm aalgorithm, const memory::desc &data_desc, + const memory::desc &diff_data_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, + convert_to_c(aalgorithm), &diff_data_desc.data, + &data_desc.data, local_size, alpha, beta, k), + "could not create a lrn backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_pooling Pooling +/// A primitive to perform max or average pooling. +/// +/// @sa @ref c_api_pooling in @ref c_api +/// @{ + +struct pooling_forward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims kernel, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, &dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a forward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct pooling_backward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::dims &strides, + const memory::dims &kernel, + const memory::dims &padding_l, + const memory::dims &padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data, + convert_to_c(aalgorithm), + &diff_src_desc.data, &diff_dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a backward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_eltwise Eltwise +/// A primitive to compute element-wise operations like parametric rectifier +/// linear unit (ReLU). +/// +/// @sa @ref c_api_eltwise in @ref c_api +/// @{ + +struct eltwise_forward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + template + desc(prop_kind aprop_kind, algorithm alg_kind, + const memory::desc &src_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + mkldnn::convert_to_c(alg_kind), &src_desc.data, + static_cast(alpha), static_cast(beta)), + "could not create a eltwise forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct eltwise_backward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + + template + desc(algorithm alg_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data, + mkldnn::convert_to_c(alg_kind), &diff_data_desc.data, + &data_desc.data, static_cast(alpha), + static_cast(beta)), + "could not create a eltwise backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_softmax Softmax +/// A primitive to perform softmax. +/// +/// @sa @ref c_api_softmax in @ref c_api +/// @{ + +struct softmax_forward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + softmax_axis), + "could not create a softmax forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct softmax_backward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(const memory::desc &diff_desc, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data, + &diff_desc.data, &data_desc.data, softmax_axis), + "could not init a backward softmax descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_batch_norm Batch normalization +/// A primitive to perform batch normalization. +/// +/// @sa @ref c_api_batch_normalization in @ref c_api +/// @{ + +struct batch_normalization_forward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template + desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, + unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + static_cast(epsilon), flags), + "could not create a batch normalization forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + + memory::desc mean_desc() const { return stat_desc(mean); } + memory::desc variance_desc() const { return stat_desc(var); } + + private: + enum { mean = 1, var = 2, }; + memory::desc stat_desc(int kind) const { + mkldnn_batch_normalization_desc_t *p; + error::wrap_c_api(mkldnn_primitive_desc_query( + get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), + "could not get a batch-normalization descriptor"); + return query_md(p->flags & use_global_stats ? src_md : dst_md, kind); + } + }; + + batch_normalization_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct batch_normalization_backward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template + desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T epsilon, unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + &diff_data_desc.data, &data_desc.data, + static_cast(epsilon), flags), + "could not create a batch normalization backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(mean, src, 1); + REG_QUERY_MD(variance, src, 2); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + batch_normalization_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_inner_product Inner Product +/// A primitive to compute an inner product. +/// +/// @sa @ref c_api_inner_product in @ref c_api +/// @{ + +struct inner_product_forward: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, &bias_desc.data, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, nullptr, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_data: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_data_desc_init(&data, + &diff_src_desc.data, &weights_desc.data, + &diff_dst_desc.data), + "could not create a inner product backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_weights: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + &diff_bias_desc.data, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + nullptr, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_rnn RNN +/// A primitive to compute common recurrent layer. +/// +/// @sa @ref c_api_rnn in @ref c_api +/// @{ + +struct rnn_cell { + struct desc { + mkldnn_rnn_cell_desc_t c_rnn_cell_; + + desc(algorithm kind, algorithm activation_f) { + error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_, + mkldnn::convert_to_c(kind), + mkldnn::convert_to_c(activation_f), 0U, 0, 0), + "could not init an rnn cell descriptor"); + } + desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {} + + operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; } + + algorithm get_cell_kind() const + { return algorithm(c_rnn_cell_.cell_kind); } + algorithm get_activation() const + { return algorithm(c_rnn_cell_.activation_kind); } + + float get_alpha() const { return c_rnn_cell_.alpha; } + void set_alpha(float alpha) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; + c_rnn_cell_.alpha = alpha; + } + + float get_clipping() const { return c_rnn_cell_.clipping; } + void set_clipping(float clipping) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; + c_rnn_cell_.clipping = clipping; + } + + int get_gates_count() const { + return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); + } + int get_state_count() const { + return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); + } + }; +}; + +struct rnn_forward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc + ) { + error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data), + "could not create an RNN forward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + rnn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct rnn_backward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc) { + error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data, + &diff_src_layer_desc.data, &diff_src_iter_desc.data, + &diff_weights_layer_desc.data, + &diff_weights_iter_desc.data, &diff_bias_desc.data, + &diff_dst_layer_desc.data, &diff_dst_iter_desc.data), + "could not create an RNN backward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src_layer, diff_src, 0); + REG_QUERY_MD(diff_src_iter, diff_src, 1); + REG_QUERY_MD(diff_weights_layer, diff_weights, 0); + REG_QUERY_MD(diff_weights_iter, diff_weights, 1); + REG_QUERY_MD(diff_bias, diff_weights, 2); + REG_QUERY_MD(diff_dst_layer, diff_dst, 0); + REG_QUERY_MD(diff_dst_iter, diff_dst, 1); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + // With last iteration (with and without input src_iter) + rnn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_shuffle Shuffle +/// A primitive to shuffle data along the axis. +/// +/// @sa @ref c_api_shuffle in @ref c_api +/// @{ + +struct shuffle_forward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + axis, group_size), + "could not create a shuffle forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct shuffle_backward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(const memory::desc &diff_data_desc, int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data, + &diff_data_desc.data, axis, group_size), + "could not create a shuffle backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const shuffle_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @} Primitives + +/// @} C++ API + +#undef REG_QUERY_MD + +// implementation section +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) { + mkldnn_primitive_t result; + error::wrap_c_api(mkldnn_primitive_create(&result, c_pd), + "could not create a primitive"); + reset(result); +} + +inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {} + +inline void primitive::execute(stream &astream, + const std::unordered_map &args) const { + std::vector c_args; + c_args.reserve(args.size()); + for (const auto &a: args) + c_args.push_back({a.first, a.second.get()}); + + error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(), + (int)c_args.size(), c_args.data()), + "primitive execution fail"); +} +#endif // DOXYGEN_SHOULD_SKIP_THIS + +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h new file mode 100644 index 0000000000..f4dc2fdfa6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#ifndef MKLDNN_DEBUG_H +#define MKLDNN_DEBUG_H + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +/* All symbols shall be internal unless marked as MKLDNN_API */ +#if defined _WIN32 || defined __CYGWIN__ +# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) +# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) +#else +# if __GNUC__ >= 4 +# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) +# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) +# else +# define MKLDNN_HELPER_DLL_IMPORT +# define MKLDNN_HELPER_DLL_EXPORT +# endif +#endif + +#ifdef MKLDNN_DLL +# ifdef MKLDNN_DLL_EXPORTS +# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT +# else +# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT +# endif +#else +# define MKLDNN_API +#endif + +#if defined (__GNUC__) +# define MKLDNN_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +# define MKLDNN_DEPRECATED __declspec(deprecated) +#else +# define MKLDNN_DEPRECATED +#endif + +#include "mkldnn_types.h" +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +#ifdef __cplusplus +extern "C" { +#endif + +const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v); +const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v); +const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v); +const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v); +const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v); +const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v); +const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v); +const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v); + +/** Forms a format string for a given memory descriptor. + * + * The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'. + * Here: + * - dt -- data type + * - p -- indicates there is non-trivial padding + * - o -- indicates there is non-trivial padding offset + * - 0 -- indicates there is non-trivial offset0 + * - fmt_kind -- format kind (blocked, wino, etc...) + * - fmt -- extended format string (format_kind specific) + * - extra -- shows extra fields (underspecified) + */ +int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len, + const mkldnn_memory_desc_t *md); + +/** Forms a dimension string for a given memory descriptor. + * + * The format is defined as: 'dim0xdim1x...xdimN + */ +int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len, + const mkldnn_memory_desc_t *md); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h new file mode 100644 index 0000000000..1b6c356982 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h @@ -0,0 +1,1415 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_TYPES_H +#define MKLDNN_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include +#include +#endif + +/** @addtogroup c_api C API + * @{ + * + * @addtogroup c_api_types Types + * @{ + * + * @addtogroup c_api_types_generic Generic + * @{ */ + +/** Intel(R) MKL-DNN Version type */ +typedef struct { + int major; + int minor; + int patch; + const char *hash; +} mkldnn_version_t; + +/** Status values returned by Intel(R) MKL-DNN functions. */ +typedef enum { + /** The operation was successful */ + mkldnn_success = 0, + /** The operation failed due to an out-of-memory condition */ + mkldnn_out_of_memory = 1, + /** The operation failed and should be retried */ + mkldnn_try_again = 2, + /** The operation failed because of incorrect function arguments */ + mkldnn_invalid_arguments = 3, + /** The operation failed because a primitive was not ready for execution */ + mkldnn_not_ready = 4, + /** The operation failed because requested functionality is not implemented + */ + mkldnn_unimplemented = 5, + /** Primitive iterator passed over last primitive descriptor */ + mkldnn_iterator_ends = 6, + /** Primitive or engine failed on execution */ + mkldnn_runtime_error = 7, + /** Queried element is not required for given primitive */ + mkldnn_not_required = 8, +} mkldnn_status_t; + +/** Data type specification */ +typedef enum { + /** Undefined data type, used for empty memory descriptors. */ + mkldnn_data_type_undef = 0, + /** 32-bit/single-precision floating point. */ + mkldnn_f32 = 1, + /** 32-bit signed integer. */ + mkldnn_s32 = 2, + /** 8-bit signed integer. */ + mkldnn_s8 = 3, + /** 8-bit unsigned integer. */ + mkldnn_u8 = 4, +} mkldnn_data_type_t; + +/** Memory format kind */ +typedef enum { + /** Undefined memory format, used for empty memory descriptors. */ + mkldnn_format_kind_undef = 0, + /** Unspecified format. The primitive selects a format automatically. */ + mkldnn_format_kind_any, + /** A tensor in a generic format described by the stride and blocking + * values in each dimension. See #mkldnn_blocking_desc_t for more + * information. */ + mkldnn_blocked, + /** Weights format used in 8bit Winograd convolution */ + mkldnn_format_kind_wino, + /** Packed weights format used in RNN */ + mkldnn_format_kind_rnn_packed, +} mkldnn_format_kind_t; + +/** Memory format tag specification. + * + * Intel MKL-DNN formats describe physical data layout. The physical layout + * is described as a sequence of the dimensions as they are laid out in the + * memory (from the outer-most to the inner-most). Note that this order + * doesn't affect the logical order of the dimensions that is kept in the + * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the + * dimensions is specified by the type of tensor. + * + * For example, CNN 5D tensor always has its logical dimensions in the order + * `(batch, channels, depth, height, width)`, while the physical layout might be + * #mkldnn_ncdhw or #mkldnn_ndhwc: + * + * ~~~cpp + * int batch = 2, channels = 16, depth = 13, height = 13, width = 13; + * + * int ndims = 5; // 5D tensor + * mkldnn_dims_t dims = {batch, channels, depth, height, width}; + * mkldnn_memory_desc_t data_in_ncdhw; + * mkldnn_memory_desc_init_by_tag( + * &data_in_ncdhw, 5, dims, mkldnn_f32, mkldnn_ncdhw); + * + * // note that in both cases dims passed are the same + * mkldnn_memory_desc_t data_in_ndhwc; + * mkldnn_memory_desc_init_by_tag( + * &data_in_ndhwc, 5, dims, mkldnn_f32, mkldnn_ndhwc); + * ~~~ + * + * The following notation applies to memory format names: + * - @c 'n' denotes the mini-batch dimension + * - @c 'c' denotes a channels dimension + * - When there are multiple channel dimensions (for example, in convolution + * weights tensor), @c 'i' and @c 'o' denote dimensions of input and output + * channels + * - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width + * respectively + * - Upper-case letters indicate that the data is laid out in blocks + * for a particular dimension. In such cases, the format name contains both + * upper- and lower-case letters for that dimension with a lower-case letter + * preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a + * format where the outermost dimension is mini-batch, followed by the + * channel block number, followed by the spatial height and width, and + * finally followed by 8-element channel blocks. + * + * @note + * Channel designations can be different. For example, both the @c + * 'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D + * tensor. + * + * @sa @ref understanding_memory_formats + */ +typedef enum { + /** Undefined memory format tag */ + mkldnn_format_tag_undef = 0, + /** Undefined memory format tag. + * The primitive selects a format automatically. */ + mkldnn_format_tag_any, + + /* Semantic agnostic section */ + /* The physical order of dimensions is defined by the permutation of the + * characters, assuming that ab..z defines the natural order. + */ + + /* Plain formats */ + + mkldnn_a, + mkldnn_ab, + mkldnn_abc, + mkldnn_abcd, + mkldnn_abcde, + mkldnn_abcdef, + mkldnn_abdec, + mkldnn_acb, + mkldnn_acbde, + mkldnn_acdb, + mkldnn_acdeb, + mkldnn_ba, + mkldnn_bac, + mkldnn_bacd, + mkldnn_bcda, + mkldnn_cba, + mkldnn_cdba, + mkldnn_cdeba, + mkldnn_decab, + + /* Opaque blocked formats */ + + mkldnn_Abc16a, + mkldnn_ABc16a16b, + mkldnn_aBc16b, + mkldnn_ABc16b16a, + mkldnn_Abc4a, + mkldnn_aBc4b, + mkldnn_ABc4b16a4b, + mkldnn_ABc4b4a, + mkldnn_ABc8a16b2a, + mkldnn_ABc8a8b, + mkldnn_aBc8b, + mkldnn_ABc8b16a2b, + mkldnn_ABc8b8a, + mkldnn_Abcd16a, + mkldnn_ABcd16a16b, + mkldnn_aBcd16b, + mkldnn_ABcd16b16a, + mkldnn_aBCd16b16c, + mkldnn_aBCd16c16b, + mkldnn_Abcd4a, + mkldnn_aBcd4b, + mkldnn_ABcd4b16a4b, + mkldnn_ABcd4b4a, + mkldnn_aBCd4c16b4c, + mkldnn_aBCd4c4b, + mkldnn_ABcd8a16b2a, + mkldnn_ABcd8a8b, + mkldnn_aBcd8b, + mkldnn_ABcd8b16a2b, + mkldnn_aBCd8b16c2b, + mkldnn_ABcd8b8a, + mkldnn_aBCd8b8c, + mkldnn_aBCd8c16b2c, + mkldnn_aBCd8c8b, + mkldnn_Abcde16a, + mkldnn_ABcde16a16b, + mkldnn_aBcde16b, + mkldnn_ABcde16b16a, + mkldnn_aBCde16b16c, + mkldnn_aBCde16c16b, + mkldnn_aBCde2c8b4c, + mkldnn_Abcde4a, + mkldnn_aBcde4b, + mkldnn_ABcde4b4a, + mkldnn_aBCde4b4c, + mkldnn_aBCde4c16b4c, + mkldnn_aBCde4c4b, + mkldnn_Abcde8a, + mkldnn_ABcde8a8b, + mkldnn_aBcde8b, + mkldnn_ABcde8b16a2b, + mkldnn_aBCde8b16c2b, + mkldnn_ABcde8b8a, + mkldnn_aBCde8b8c, + mkldnn_aBCde8c16b2c, + mkldnn_aBCde8c8b, + mkldnn_aBcdef16b, + mkldnn_aBCdef16b16c, + mkldnn_aBCdef16c16b, + mkldnn_aBcdef4b, + mkldnn_aBCdef4c4b, + mkldnn_aBCdef8b8c, + mkldnn_aBCdef8c16b2c, + mkldnn_aBCdef8c8b, + mkldnn_aBdc16b, + mkldnn_aBdc4b, + mkldnn_aBdc8b, + mkldnn_aBdec16b, + mkldnn_aBdec4b, + mkldnn_aBdec8b, + mkldnn_aBdefc16b, + mkldnn_aBdefc4b, + mkldnn_aBdefc8b, + mkldnn_Acb16a, + mkldnn_Acb4a, + mkldnn_Acb8a, + mkldnn_aCBd16b16c, + mkldnn_aCBde16b16c, + mkldnn_Acdb16a, + mkldnn_Acdb4a, + mkldnn_Acdb8a, + mkldnn_Acdeb16a, + mkldnn_Acdeb4a, + mkldnn_Acdeb8a, + mkldnn_BAc16a16b, + mkldnn_BAcd16a16b, + + /** Just a sentinel, not real memory format tag. Must be changed after new + * format tag is added. */ + mkldnn_format_tag_last, + + /* Aliases */ + + mkldnn_x = mkldnn_a, + mkldnn_nc = mkldnn_ab, + mkldnn_cn = mkldnn_ba, + mkldnn_ncw = mkldnn_abc, + mkldnn_nwc = mkldnn_acb, + mkldnn_nchw = mkldnn_abcd, + mkldnn_nhwc = mkldnn_acdb, + mkldnn_chwn = mkldnn_bcda, + mkldnn_ncdhw = mkldnn_abcde, + mkldnn_ndhwc = mkldnn_acdeb, + + mkldnn_oi = mkldnn_ab, + mkldnn_io = mkldnn_ba, + mkldnn_oiw = mkldnn_abc, + mkldnn_wio = mkldnn_cba, + mkldnn_oihw = mkldnn_abcd, + mkldnn_hwio = mkldnn_cdba, + mkldnn_ihwo = mkldnn_bcda, + mkldnn_iohw = mkldnn_bacd, + mkldnn_oidhw = mkldnn_abcde, + mkldnn_dhwio = mkldnn_cdeba, + mkldnn_goiw = mkldnn_abcd, + mkldnn_goihw = mkldnn_abcde, + mkldnn_hwigo = mkldnn_decab, + mkldnn_giohw = mkldnn_acbde, + mkldnn_goidhw = mkldnn_abcdef, + + /** 3D RNN data tensor in the format (seq_length, batch, input channels). */ + mkldnn_tnc = mkldnn_abc, + /** 3D RNN data tensor in the format (batch, seq_length, input channels). */ + mkldnn_ntc = mkldnn_bac, + /** 5D RNN states tensor in the format (num_layers, num_directions, + * num_states, batch, state channels). */ + mkldnn_ldsnc = mkldnn_abcde, + /** 5D RNN weights tensor in the format (num_layers, num_directions, + * input_channels, num_gates, output_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldigo = mkldnn_abcde, + /** 5D RNN weights tensor in the format (num_layers, num_directions, + * num_gates, output_channels, input_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldgoi = mkldnn_abdec, + /** 4D RNN bias tensor in the format (num_layers, num_directions, + * num_gates, output_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldgo = mkldnn_abcd, + + /* Opaque data types, are not to be used explicitly */ + + /* data */ + mkldnn_nCdhw16c = mkldnn_aBcde16b, + mkldnn_nCdhw4c = mkldnn_aBcde4b, + mkldnn_nCdhw8c = mkldnn_aBcde8b, + mkldnn_nChw16c = mkldnn_aBcd16b, + mkldnn_nChw4c = mkldnn_aBcd4b, + mkldnn_nChw8c = mkldnn_aBcd8b, + mkldnn_nCw16c = mkldnn_aBc16b, + mkldnn_nCw4c = mkldnn_aBc4b, + mkldnn_nCw8c = mkldnn_aBc8b, + + /* weights, 3D */ + mkldnn_IOw16o16i = mkldnn_BAc16a16b, + mkldnn_OIw16i16o = mkldnn_ABc16b16a, + mkldnn_OIw16o16i = mkldnn_ABc16a16b, + mkldnn_Oiw16o = mkldnn_Abc16a, + mkldnn_OIw4i16o4i = mkldnn_ABc4b16a4b, + mkldnn_OIw4i4o = mkldnn_ABc4b4a, + mkldnn_Oiw4o = mkldnn_Abc4a, + mkldnn_OIw8i16o2i = mkldnn_ABc8b16a2b, + mkldnn_OIw8i8o = mkldnn_ABc8b8a, + mkldnn_OIw8o16i2o = mkldnn_ABc8a16b2a, + mkldnn_OIw8o8i = mkldnn_ABc8a8b, + mkldnn_Owi16o = mkldnn_Acb16a, + mkldnn_Owi4o = mkldnn_Acb4a, + mkldnn_Owi8o = mkldnn_Acb8a, + + /* weights, 4D */ + mkldnn_IOhw16o16i = mkldnn_BAcd16a16b, + mkldnn_Ohwi16o = mkldnn_Acdb16a, + mkldnn_Ohwi4o = mkldnn_Acdb4a, + mkldnn_Ohwi8o = mkldnn_Acdb8a, + mkldnn_OIhw16i16o = mkldnn_ABcd16b16a, + mkldnn_OIhw16o16i = mkldnn_ABcd16a16b, + mkldnn_Oihw16o = mkldnn_Abcd16a, + mkldnn_OIhw4i16o4i = mkldnn_ABcd4b16a4b, + mkldnn_OIhw4i4o = mkldnn_ABcd4b4a, + mkldnn_Oihw4o = mkldnn_Abcd4a, + mkldnn_OIhw8i16o2i = mkldnn_ABcd8b16a2b, + mkldnn_OIhw8i8o = mkldnn_ABcd8b8a, + mkldnn_OIhw8o16i2o = mkldnn_ABcd8a16b2a, + mkldnn_OIhw8o8i = mkldnn_ABcd8a8b, + + /* weights, 5D */ + mkldnn_Odhwi16o = mkldnn_Acdeb16a, + mkldnn_Odhwi4o = mkldnn_Acdeb4a, + mkldnn_Odhwi8o = mkldnn_Acdeb8a, + mkldnn_OIdhw16i16o = mkldnn_ABcde16b16a, + mkldnn_OIdhw16o16i = mkldnn_ABcde16a16b, + mkldnn_Oidhw16o = mkldnn_Abcde16a, + mkldnn_OIdhw4i4o = mkldnn_ABcde4b4a, + mkldnn_Oidhw4o = mkldnn_Abcde4a, + mkldnn_OIdhw8i16o2i = mkldnn_ABcde8b16a2b, + mkldnn_OIdhw8i8o = mkldnn_ABcde8b8a, + mkldnn_OIdhw8o8i = mkldnn_ABcde8a8b, + + /* weights w/ groups, 3D */ + mkldnn_Goiw16g = mkldnn_Abcd16a, + mkldnn_gIOw16o16i = mkldnn_aCBd16b16c, + mkldnn_gOIw16i16o = mkldnn_aBCd16c16b, + mkldnn_gOIw16o16i = mkldnn_aBCd16b16c, + mkldnn_gOiw16o = mkldnn_aBcd16b, + mkldnn_gOIw4i16o4i = mkldnn_aBCd4c16b4c, + mkldnn_gOIw4i4o = mkldnn_aBCd4c4b, + mkldnn_gOiw4o = mkldnn_aBcd4b, + mkldnn_gOIw8i16o2i = mkldnn_aBCd8c16b2c, + mkldnn_gOIw8i8o = mkldnn_aBCd8c8b, + mkldnn_gOIw8o16i2o = mkldnn_aBCd8b16c2b, + mkldnn_gOIw8o8i = mkldnn_aBCd8b8c, + mkldnn_gOwi16o = mkldnn_aBdc16b, + mkldnn_gOwi4o = mkldnn_aBdc4b, + mkldnn_gOwi8o = mkldnn_aBdc8b, + + /* weights w/ groups, 4D */ + mkldnn_gIOhw16o16i = mkldnn_aCBde16b16c, + mkldnn_gOhwi16o = mkldnn_aBdec16b, + mkldnn_gOhwi4o = mkldnn_aBdec4b, + mkldnn_gOhwi8o = mkldnn_aBdec8b, + mkldnn_Goihw16g = mkldnn_Abcde16a, + mkldnn_gOIhw16i16o = mkldnn_aBCde16c16b, + mkldnn_gOIhw16o16i = mkldnn_aBCde16b16c, + mkldnn_gOihw16o = mkldnn_aBcde16b, + mkldnn_gOIhw2i8o4i = mkldnn_aBCde2c8b4c, + mkldnn_gOIhw4i16o4i = mkldnn_aBCde4c16b4c, + mkldnn_gOIhw4i4o = mkldnn_aBCde4c4b, + mkldnn_gOIhw4o4i = mkldnn_aBCde4b4c, + mkldnn_gOihw4o = mkldnn_aBcde4b, + mkldnn_Goihw8g = mkldnn_Abcde8a, + mkldnn_gOIhw8i16o2i = mkldnn_aBCde8c16b2c, + mkldnn_gOIhw8i8o = mkldnn_aBCde8c8b, + mkldnn_gOIhw8o16i2o = mkldnn_aBCde8b16c2b, + mkldnn_gOIhw8o8i = mkldnn_aBCde8b8c, + + /* weights w/ groups, 6D */ + mkldnn_gOdhwi16o = mkldnn_aBdefc16b, + mkldnn_gOdhwi4o = mkldnn_aBdefc4b, + mkldnn_gOdhwi8o = mkldnn_aBdefc8b, + mkldnn_gOIdhw16i16o = mkldnn_aBCdef16c16b, + mkldnn_gOIdhw16o16i = mkldnn_aBCdef16b16c, + mkldnn_gOidhw16o = mkldnn_aBcdef16b, + mkldnn_gOIdhw4i4o = mkldnn_aBCdef4c4b, + mkldnn_gOidhw4o = mkldnn_aBcdef4b, + mkldnn_gOIdhw8i16o2i = mkldnn_aBCdef8c16b2c, + mkldnn_gOIdhw8i8o = mkldnn_aBCdef8c8b, + mkldnn_gOIdhw8o8i = mkldnn_aBCdef8b8c, +} mkldnn_format_tag_t; + +/** Kinds of padding. Define how to interpret the data in padding regions. */ +typedef enum { + /** The data in padding regions is zero. */ + mkldnn_padding_zero, +} mkldnn_padding_kind_t; + +/** Kinds of propagation. */ +typedef enum { + /* TODO: suggest renames */ + /** Undefined propagation type. */ + mkldnn_prop_kind_undef = 0, + /** Forward data propagation (training mode). In this mode primitives + * perform computations necessary for subsequent backward propagation. */ + mkldnn_forward_training = 64, + /** Forward data propagation (inference mode). In this mode primitives + * perform only computations that are necessary for inference and omit + * computations that are necessary only for backward propagation. */ + mkldnn_forward_inference = 96, + /** Forward data propagation (alias for @c mkldnn_forward_inference) */ + mkldnn_forward_scoring = mkldnn_forward_inference, + /** Forward data propagation (alias for @c mkldnn_forward_training) */ + mkldnn_forward = mkldnn_forward_training, + /** Backward propagation (with respect to all parameters */ + mkldnn_backward = 128, + /** Backward data propagation */ + mkldnn_backward_data = 160, + /** Backward weights propagation */ + mkldnn_backward_weights = 192, + /** Backward bias propagation */ + mkldnn_backward_bias = 193, +} mkldnn_prop_kind_t; + +/** Kinds of primitives. Used to implement a way to extend the library with new + * primitives without changing the ABI. */ +typedef enum { + /** Undefined primitive (XXX: why do we have it?). */ + mkldnn_undefined_primitive, + /** A reorder primitive.*/ + mkldnn_reorder, + /** A shuffle primitive.*/ + mkldnn_shuffle, + /** A (out-of-place) concat primitive. */ + mkldnn_concat, + /** A sum primitive. */ + mkldnn_sum, + /** A convolution primitive. */ + mkldnn_convolution, + /** A deconvolution primitive. */ + mkldnn_deconvolution, + /** An element-wise primitive. */ + mkldnn_eltwise, + /** A Softmax primitive. */ + mkldnn_softmax, + /** A pooling primitive. */ + mkldnn_pooling, + /** An LRN primitive. */ + mkldnn_lrn, + /** An batch normalization primitive. */ + mkldnn_batch_normalization, + /** An inner product primitive. */ + mkldnn_inner_product, + /** A rnn primitive. */ + mkldnn_rnn, +} mkldnn_primitive_kind_t; + +/** Kinds of algorithms. */ +typedef enum { + mkldnn_alg_kind_undef, + /** Direct convolution */ + mkldnn_convolution_direct = 0x1, + /** Winograd convolution */ + mkldnn_convolution_winograd = 0x2, + /** Convolution algorithm(either direct or Winograd) is chosen just in time **/ + mkldnn_convolution_auto = 0x3, + /** Direct deconvolution */ + mkldnn_deconvolution_direct = 0xa, + /** Winograd deconvolution */ + mkldnn_deconvolution_winograd = 0xb, + /** Eltwise: ReLU */ + mkldnn_eltwise_relu = 0x1f, + /** Eltwise: hyperbolic tangent non-linearity (tanh) */ + mkldnn_eltwise_tanh = 0x2f, + /** Eltwise: parametric exponential linear unit (elu) */ + mkldnn_eltwise_elu = 0x3f, + /** Eltwise: square */ + mkldnn_eltwise_square = 0x4f, + /** Eltwise: abs */ + mkldnn_eltwise_abs = 0x5f, + /** Eltwise: square root */ + mkldnn_eltwise_sqrt = 0x6f, + /** Eltwise: linear */ + mkldnn_eltwise_linear = 0x7f, + /** Eltwise: bounded_relu */ + mkldnn_eltwise_bounded_relu = 0x8f, + /** Eltwise: soft_relu */ + mkldnn_eltwise_soft_relu = 0x9f, + /** Eltwise: logistic */ + mkldnn_eltwise_logistic = 0xaf, + /** Max pooling */ + mkldnn_pooling_max = 0x1ff, + /** Average pooling include padding */ + mkldnn_pooling_avg_include_padding = 0x2ff, + /** Average pooling exclude padding */ + mkldnn_pooling_avg_exclude_padding = 0x3ff, + mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding, + /** Local response normalization (LRN) across multiple channels */ + mkldnn_lrn_across_channels = 0xaff, + /** LRN within a single channel */ + mkldnn_lrn_within_channel = 0xbff, + /** RNN cell */ + mkldnn_vanilla_rnn = 0x1fff, + /** LSTM cell */ + mkldnn_vanilla_lstm = 0x2fff, + /** GRU cell */ + mkldnn_vanilla_gru = 0x3fff, + /** GRU cell with linear before reset + * + * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru + * in how the new memory gate is calculated: + * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f] + * Primitive expects 4 biases on input: + * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$ + * */ + mkldnn_gru_linear_before_reset = 0x4fff, +} mkldnn_alg_kind_t; + +/** Flags for batch-normalization primititve. */ +typedef enum { + /** Use global statistics + * + * If specified + * - on forward propagation use mean and variance provided by user (input) + * - on backward propagation reduces the amount of computations, since + * mean and variance are considered as constants + * + * If not specified: + * - on forward propagation mean and variance are computed and stored in + * output + * - on backward propagation compute full derivative wrt to data + */ + mkldnn_use_global_stats = 0x1U, + /** Use scale and shift parameters + * + * If specified: + * - on forward propagation use scale and shift (aka scale and bias) for + * the batch normalization results + * - on backward propagation (for prop_kind == #mkldnn_backward) compute + * diff wrt to scale and shift (hence one extra output used) + * + * If no specified: + * - on backward propagation prop_kind == #mkldnn_backward_data has the + * same behavior as prop_kind == #mkldnn_backward + */ + mkldnn_use_scaleshift = 0x2U, + /** Fuse with ReLU + * + * If specified: + * - on inference this option behaves the same as if the primitive were + * fused with ReLU via post ops API + * - on training primitive requires workspace (required to be able to + * perform backward pass) + */ + mkldnn_fuse_bn_relu = 0x4U, +} mkldnn_batch_normalization_flag_t; + +/** @} */ + +/** @addtogroup c_api_types_memory Memory + * @{ */ + +/** Maximum number of dimensions a tensor can have. Only restricts the amount + * of space used for the tensor description. Individual computational + * primitives may support only tensors of certain dimensions. */ +#define MKLDNN_MAX_NDIMS 12 + +/** A type to describe tensor dimension. */ +typedef int64_t mkldnn_dim_t; + +/** A type to describe tensor dimensions. */ +typedef mkldnn_dim_t mkldnn_dims_t[MKLDNN_MAX_NDIMS]; + +/** A type to describe strides within a tensor. */ +typedef mkldnn_dim_t mkldnn_strides_t[MKLDNN_MAX_NDIMS]; + +/** Generic description of blocked data layout for most memory formats. + * + * @sa @ref understanding_memory_formats */ +typedef struct { + /** The strides between the outermost blocks. + * In case of plain (non-blocked) formats the strides between dimensions. */ + mkldnn_dims_t strides; + /* Innermost section + * ASSUMPTION: the innermost blocks are always dense */ + /** The number of innermost blocks, e.g. 3 in case of `OIhw_4i16o4i_` */ + int inner_nblks; + /** The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i` */ + mkldnn_dims_t inner_blks; + /** The logical indices of the blocks, e.g. `{1, 0, 1}` in case of + * `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim */ + mkldnn_dims_t inner_idxs; +} mkldnn_blocking_desc_t; + +typedef enum { + /** Undefined memory format, used for empty memory descriptors. */ + mkldnn_wino_undef = 0, + /** Tensors of weights for 2x3 winograd convolutions. */ + mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, + /** Tensor of weights for 4x3 convolution. */ + mkldnn_wino_wei_OBaaIBOIio +} mkldnn_wino_memory_format_t; + +/** Description of tensor of weights for winograd 2x3 convolution. */ +typedef struct { + mkldnn_wino_memory_format_t wino_format; + int r; + int alpha; + int ic; + int oc; + int ic_block; + int oc_block; + int ic2_block; + int oc2_block; + float adj_scale; + size_t size; +} mkldnn_wino_desc_t; + +typedef enum { + mkldnn_packed_format_undef = 0, + mkldnn_ldigo_p, + mkldnn_ldgoi_p +} mkldnn_rnn_packed_memory_format_t; + +/* Maximum number of parts of RNN weights tensor that require separate + * computation. */ +#define MKLDNN_RNN_MAX_N_PARTS 4 + +/** Description of tensor of packed weights for rnn. */ +typedef struct { + mkldnn_rnn_packed_memory_format_t format; + int n_parts; + int n; + int parts[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + size_t offset_compensation; + size_t size; +} mkldnn_rnn_packed_desc_t; + +typedef enum { + mkldnn_memory_extra_flag_none = 0x0U, + /** Indicates the weights have an additional buffer, that depends on the + * @p compensation_mask. + * + * For instance, in 4D case with the compensation mask equals (1 << 0) + * the additional buffer would consist of OC values: + * O[oc : 0,OC] = + * -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) } + */ + mkldnn_memory_extra_flag_compensation_conv_s8s8 = 0x1U, + mkldnn_memory_extra_flag_scale_adjust = 0x2U, +} mkldnn_memory_extra_flags_t; + +/** Description of extra information stored in memory */ +typedef struct { + /** The flags contain arbitrary extra information, such as compensation. + * @sa mkldnn_memory_extra_flags_t */ + uint64_t flags; + /** Compensation mask */ + int compensation_mask; + /** Scale applied to the data */ + float scale_adjust; + /** For future backwards compatibility */ + char reserved[64]; +} mkldnn_memory_extra_desc_t; + +/** Memory descriptor. The description is based on a number of dimensions, + * dimensions themselves, plus information about elements type and memory + * format. Additionally, contains format-specific descriptions of the data + * layout. */ +typedef struct { + /** Number of dimensions */ + int ndims; + /** Dimensions in the following order: + * - CNN data tensors: mini-batch, channel, spatial + * ({N, C, [[D,] H,] W}) + * - CNN weight tensors: group (optional), output channel, input channel, + * spatial ({[G,] O, I, [[D,] H,] W}) + * - RNN data tensors: time, mini-batch, channels ({T, N, C}) + * or layers, directions, states, mini-batch, channels ({L, D, S, N, C}) + * - RNN weight tensor: layers, directions, input channel, gates, output channels + * ({L, D, I, G, O}). + * + * @note + * The order of dimensions does not depend on the memory format, so + * whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc + * the dims for 4D CN data tensor would be {N, C, H, W}. + */ + mkldnn_dims_t dims; + /** Data type of the tensor elements. */ + mkldnn_data_type_t data_type; + + /** Size of the data including padding in each dimension. */ + mkldnn_dims_t padded_dims; + /** Per-dimension offset from the padding to actual data, the top-level + * tensor with offsets applied must lie within the padding area. */ + mkldnn_dims_t padded_offsets; + + /** Offset from memory origin to the current block, non-zero only in + * a description of a memory sub-block. */ + mkldnn_dim_t offset0; + + /** Memory format kind. */ + mkldnn_format_kind_t format_kind; + union { + /** Description of the data layout for memory formats that use + * blocking. */ + mkldnn_blocking_desc_t blocking; + /** Tensor of weights for integer 8bit winograd convolution. */ + mkldnn_wino_desc_t wino_desc; + /** Tensor of packed weights for RNN. */ + mkldnn_rnn_packed_desc_t rnn_packed_desc; + /* ... other descriptions possible */ + } format_desc; + + mkldnn_memory_extra_desc_t extra; +} mkldnn_memory_desc_t; + +/** @struct mkldnn_memory + * An opaque structure to describe a memory. */ +struct mkldnn_memory; + +/** A memory handle. */ +typedef struct mkldnn_memory *mkldnn_memory_t; + +/** A constant memory handle. */ +typedef const struct mkldnn_memory *const_mkldnn_memory_t; + +#define MKLDNN_NATIVE_HANDLE_NONE (NULL) +#define MKLDNN_NATIVE_HANDLE_ALLOCATE ((void *)(size_t)-1) + +/** @} */ + +/** @addtogroup c_api_types_op_descs Operation descriptors + * @{*/ + +/** A pointer to any of the operation descriptors. */ +typedef void *mkldnn_op_desc_t; +/** A pointer to any of the operation descriptors (constant variant). */ +typedef const void *const_mkldnn_op_desc_t; + +/** A descriptor of a convolution operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_convolution. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward_data, + * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ + mkldnn_prop_kind_t prop_kind; + /** The kind of the convolution algorithm. Possible values: + * #mkldnn_convolution_direct. */ + mkldnn_alg_kind_t alg_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Weights memory descriptor. */ + mkldnn_memory_desc_t weights_desc; + /** Weights gradient memory descriptor. */ + mkldnn_memory_desc_t diff_weights_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** Convolution strides in each spatial dimension. */ + mkldnn_dims_t strides; + /** Convolution dilates in each spatial dimension. */ + mkldnn_dims_t dilates; + /** Padding in each spatial dimension. padding[0] is a padding in the + * beginning (@p padding_l), padding[1] is a padding in the end (@p + * padding_r). */ + mkldnn_dims_t padding[2]; + /** The kind of padding to use. */ + mkldnn_padding_kind_t padding_kind; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_convolution_desc_t; + +/** A descriptor of a deconvolution operation. */ +typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t; + +/** A descriptor of a shuffle operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_convolution. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, and #mkldnn_backward_data. */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor, + * and source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** axis for shuffling. */ + int axis; + /** number of groups in group convolution */ + mkldnn_dim_t group_size; +} mkldnn_shuffle_desc_t; + +/** A descriptor of a element-wise operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_eltwise. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu, + * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square, + * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear, + * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and + * #mkldnn_eltwise_logistic. */ + mkldnn_alg_kind_t alg_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** Algorithm specific parameter. + * Accordance table: + * - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored + * - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored + * - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored + * - #mkldnn_eltwise_square: @p alpha and @p beta ignored + * - #mkldnn_eltwise_abs: @p alpha and @p beta ignored + * - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored + * - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift + * - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored + * - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored + * - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored + */ + float alpha, beta; +} mkldnn_eltwise_desc_t; + +/** A descriptor of a Softmax operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_softmax. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training and + * #mkldnn_forward_inference. */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and Destination of gradient memory descriptor. */ + mkldnn_memory_desc_t diff_desc; + /** The axis along which to perform the softmax. */ + int softmax_axis; +} mkldnn_softmax_desc_t; + +/** A descriptor of a pooling operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_pooling. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and + * #mkldnn_pooling_avg. */ + mkldnn_alg_kind_t alg_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** Pooling kernel strides for spatial dimensions. */ + mkldnn_dims_t strides; + /** Pooling kernel spatial dimensions. */ + mkldnn_dims_t kernel; + /** Padding in each spatial dimension. padding[0] is a padding in the + * beginning (@p padding_l), padding[1] is a padding in the end (@p + * padding_r). */ + mkldnn_dims_t padding[2]; + /** The kind of padding to use. */ + mkldnn_padding_kind_t padding_kind; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_pooling_desc_t; + +/** A descriptor of a Local Response Normalization (LRN) operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_lrn. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and + * #mkldnn_lrn_across_channels. */ + mkldnn_alg_kind_t alg_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** The number of channels to sum over (for cross-channel LRN) or the side + * length of the square region to sum over (for within-channel LRN). */ + mkldnn_dim_t local_size; + /** LRN alpha parameter. */ + float lrn_alpha; + /** LRN beta parameter. */ + float lrn_beta; + /** LRN k parameter. */ + float lrn_k; +} mkldnn_lrn_desc_t; + +/** A descriptor of a Batch Normalization operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_batch_normalization. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** Scale and shift data and gradient memory descriptors. + * + * Scaleshift memory descriptor uses 2D #mkldnn_nc format[2,Channels]. 1-st + * dimension contains gamma parameter, 2-nd dimension contains beta + * parameter. */ + mkldnn_memory_desc_t data_scaleshift_desc; + mkldnn_memory_desc_t diff_data_scaleshift_desc; + /** Mean and variance data memory descriptors. + * + * Mean and variance memory descriptors use 1D #mkldnn_x format[Channels]. + */ + mkldnn_memory_desc_t mean_desc; + mkldnn_memory_desc_t variance_desc; + /** Batch normalization epsilon parameter. */ + float batch_norm_epsilon; + unsigned flags; +} mkldnn_batch_normalization_desc_t; + +/** A descriptor of an inner product operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_inner_product. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward_data, + * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ + mkldnn_prop_kind_t prop_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Weights memory descriptor. */ + mkldnn_memory_desc_t weights_desc; + /** Weights gradient memory descriptor. */ + mkldnn_memory_desc_t diff_weights_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_inner_product_desc_t; + +/** Flags for RNN cell. */ +typedef enum { + mkldnn_rnn_cell_with_relu = 0x1U, + mkldnn_rnn_cell_with_clipping = 0x2U, +} mkldnn_rnn_cell_flags_t; + +typedef struct { + /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn, + * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, + * or #mkldnn_gru_linear_before_reset. */ + mkldnn_alg_kind_t cell_kind; + /** Activation function used. Must be either #mkldnn_eltwise_relu or + * #mkldnn_eltwise_tanh. */ + mkldnn_alg_kind_t activation_kind; + /** RNN cell flags */ + unsigned int flags; + /** @c alpha is a negative slope parameter (used only if + * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */ + float alpha; + /** clipping parameter (used only if + * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */ + float clipping; +} mkldnn_rnn_cell_desc_t; + +/** A direction of RNN primitive execution. */ +typedef enum { + /* Unidirectional execution of RNN primitive from left to right. */ + mkldnn_unidirectional_left2right, + /* Unidirectional execution of RNN primitive from right to left. */ + mkldnn_unidirectional_right2left, + /* Bidirectional execution of RNN primitive with concatenation of the + * results. */ + mkldnn_bidirectional_concat, + /* Bidirectional execution of RNN primitive with summation of the + * results. */ + mkldnn_bidirectional_sum, + mkldnn_unidirectional = mkldnn_unidirectional_left2right, +} mkldnn_rnn_direction_t; + +/** A descriptor for an RNN operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_rnn. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, and #mkldnn_backward. */ + mkldnn_prop_kind_t prop_kind; + /** The RNN cell desc. */ + mkldnn_rnn_cell_desc_t cell_desc; + /** The direction of RNN primitive execution. */ + mkldnn_rnn_direction_t direction; + /** Source layer memory descriptor. */ + mkldnn_memory_desc_t src_layer_desc; + /** Source iteration memory descriptor. */ + mkldnn_memory_desc_t src_iter_desc; + /** Weights layer memory descriptor. */ + mkldnn_memory_desc_t weights_layer_desc; + /** Weights iteration memory descriptor. */ + mkldnn_memory_desc_t weights_iter_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Destination layer memory descriptor. */ + mkldnn_memory_desc_t dst_layer_desc; + /** Destination iter memory descriptor. */ + mkldnn_memory_desc_t dst_iter_desc; + /** Source gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_src_layer_desc; + /** Source gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_src_iter_desc; + /** Weights gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_weights_layer_desc; + /** Weights gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_weights_iter_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_dst_layer_desc; + /** Destination gradient iteration memory descriptor. */ + mkldnn_memory_desc_t diff_dst_iter_desc; +} mkldnn_rnn_desc_t; + +/** @} */ + +/** @addtogroup c_api_engine_types Engine + * @{ */ + +/** @brief Kinds of engines. */ +typedef enum { + /** An unspecified engine. */ + mkldnn_any_engine, + /** CPU engine. */ + mkldnn_cpu, +} mkldnn_engine_kind_t; + +/** @struct mkldnn_engine + * @brief An opaque structure to describe an engine. */ +struct mkldnn_engine; +/** @brief An engine handle. */ +typedef struct mkldnn_engine *mkldnn_engine_t; +#if 0 +/* FIXME: looks like this never happens */ +/** @brief A constant engine handle. */ +typedef const struct mkldnn_engine *const_mkldnn_engine_t; +#endif + +/** @} */ + +/** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators + * @{ */ + +/** @struct mkldnn_primitive_desc_iterator + * @brief An opaque structure to describe a primitive descriptor iterator. */ +struct mkldnn_primitive_desc_iterator; + +/** @brief A primitive descriptor iterator handle. */ +typedef struct mkldnn_primitive_desc_iterator + *mkldnn_primitive_desc_iterator_t; + +/** @brief A constant primitive descriptor iterator handle. */ +typedef const struct mkldnn_primitive_desc_iterator + *const_mkldnn_primitive_desc_iterator_t; + +/** @} */ + +/** @addtogroup c_api_primitive_descs Primitive descriptors + * @{ */ + +/** @struct mkldnn_primitive_desc + * @brief An opaque structure to describe a primitive descriptor. */ +struct mkldnn_primitive_desc; + +/** @brief A primitive descriptor handle. */ +typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t; + +/** @brief A constant primitive descriptor handle. */ +typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t; + +/** @} */ + +/** @addtogroup c_api_primitive_attr Primitive descriptor attributes + * @{ */ + +/** Scratchpad mode */ +typedef enum { + /** The library manages scratchpad (default) */ + mkldnn_scratchpad_mode_library, + /** A user shall query and provide the scratchpad memory to primitives */ + mkldnn_scratchpad_mode_user, +} mkldnn_scratchpad_mode_t; + +/** @struct mkldnn_primitive_attr + * @brief An opaque structure for primitive descriptor attributes. + * + * Attributes may contain: + * - output scales (to scale the result prior to storing it to the memory) + */ +struct mkldnn_primitive_attr; + +/** @brief A primitive descriptor attributes handle that controls primitive + * behavior. */ +typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t; + +/** @brief A constant primitive descriptor attributes handle. */ +typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t; + +/** @struct mkldnn_post_ops + * @brief An opaque structure for a chain of post operations. + * + * mkldnn_post_ops can be used to perform some (trivial) operations like + * accumulation or eltwise after certain primitives like convolution. + * + * Post operations might be combined together, making a chain of post + * operations. For instance one can configure convolution followed by + * accumulation followed by eltwise. This might be especially beneficial + * for residual learning blocks. + * + * @warning + * Of course not all combinations are supported, so the user should handle + * errors accordingly. + * + * Supported post operations: + * - accumulation (base primitive: convolution) + * - eltwise (base primitive: convolution) + */ +struct mkldnn_post_ops; + +/** @brief A post operation chain handle. */ +typedef struct mkldnn_post_ops *mkldnn_post_ops_t; + +/** @brief A constant post operation chain handle. */ +typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t; + +/** @} */ + +/** @addtogroup c_api_types_primitive Primitive + * @{ */ + +/** @struct mkldnn_primitive + * An opaque structure to describe a primitive. */ +struct mkldnn_primitive; +/** A primitive handle. */ +typedef struct mkldnn_primitive *mkldnn_primitive_t; +/** A constant primitive handle. */ +typedef const struct mkldnn_primitive *const_mkldnn_primitive_t; + +/** @addtogroup c_api_types_arguments Argument indices + * @{ */ + +#define MKLDNN_ARG_SRC_0 1 +#define MKLDNN_ARG_SRC MKLDNN_ARG_SRC_0 +#define MKLDNN_ARG_SRC_LAYER MKLDNN_ARG_SRC_0 +#define MKLDNN_ARG_FROM MKLDNN_ARG_SRC_0 + +#define MKLDNN_ARG_SRC_1 2 +#define MKLDNN_ARG_SRC_ITER MKLDNN_ARG_SRC_1 + +#define MKLDNN_ARG_DST_0 17 +#define MKLDNN_ARG_DST MKLDNN_ARG_DST_0 +#define MKLDNN_ARG_TO MKLDNN_ARG_DST_0 +#define MKLDNN_ARG_DST_LAYER MKLDNN_ARG_DST_0 + +#define MKLDNN_ARG_DST_1 18 +#define MKLDNN_ARG_DST_ITER MKLDNN_ARG_DST_1 + +#define MKLDNN_ARG_WEIGHTS_0 33 +#define MKLDNN_ARG_WEIGHTS MKLDNN_ARG_WEIGHTS_0 +#define MKLDNN_ARG_SCALE_SHIFT MKLDNN_ARG_WEIGHTS_0 +#define MKLDNN_ARG_WEIGHTS_LAYER MKLDNN_ARG_WEIGHTS_0 + +#define MKLDNN_ARG_WEIGHTS_1 34 +#define MKLDNN_ARG_WEIGHTS_ITER MKLDNN_ARG_WEIGHTS_1 + +#define MKLDNN_ARG_BIAS 41 + +#define MKLDNN_ARG_MEAN 49 +#define MKLDNN_ARG_VARIANCE 50 + +#define MKLDNN_ARG_WORKSPACE 64 +#define MKLDNN_ARG_SCRATCHPAD 80 + +#define MKLDNN_ARG_DIFF_SRC_0 129 +#define MKLDNN_ARG_DIFF_SRC MKLDNN_ARG_DIFF_SRC_0 +#define MKLDNN_ARG_DIFF_SRC_LAYER MKLDNN_ARG_DIFF_SRC_0 + +#define MKLDNN_ARG_DIFF_SRC_1 130 +#define MKLDNN_ARG_DIFF_SRC_ITER MKLDNN_ARG_DIFF_SRC_1 + +#define MKLDNN_ARG_DIFF_DST_0 145 +#define MKLDNN_ARG_DIFF_DST MKLDNN_ARG_DIFF_DST_0 +#define MKLDNN_ARG_DIFF_DST_LAYER MKLDNN_ARG_DIFF_DST_0 + +#define MKLDNN_ARG_DIFF_DST_1 146 +#define MKLDNN_ARG_DIFF_DST_ITER MKLDNN_ARG_DIFF_DST_1 + +#define MKLDNN_ARG_DIFF_WEIGHTS_0 161 +#define MKLDNN_ARG_DIFF_WEIGHTS MKLDNN_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_SCALE_SHIFT MKLDNN_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER MKLDNN_ARG_DIFF_WEIGHTS_0 + +#define MKLDNN_ARG_DIFF_WEIGHTS_1 162 +#define MKLDNN_ARG_DIFF_WEIGHTS_ITER MKLDNN_ARG_DIFF_WEIGHTS_1 + +#define MKLDNN_ARG_DIFF_BIAS 169 + +#define MKLDNN_ARG_MULTIPLE_SRC 1024 +#define MKLDNN_ARG_MULTIPLE_DST 2048 + +/** @} */ + +/** An auxiliary structure to specify primitive's inputs/outputs at execution + * + * @warning + * With this API it's impossible to preserve constness of memory, so all + * memories are passed w/o const qualifier. However only memories with + * output semantics might be changed during the execution */ +typedef struct { + int arg; /**< An argument index, e.g. MKLDNN_ARG_SRC */ + mkldnn_memory_t memory; /**< Input/output memory */ +} mkldnn_exec_arg_t; + +/** @} */ + +/** @addtogroup c_api_types_query Queries + * @{ */ + +/** Primitive descriptor query specification + * + * For generic function mkldnn_primitive_desc_query(), the type of result must + * agree with the queried argument. The correspondence table: + * Query | type of result + * -------------------------------------------------------------- + * #mkldnn_query_engine | mkldnn_engine_t * + * #mkldnn_query_scratchpad_engine | mkldnn_engine_t * + * #mkldnn_query_primitive_kind | mkldnn_primitive_kind_t * + * *_s32 | int * + * *_s64 | mkldnn_dim_t * (same as int64_t *) + * *_f64 | double * + * *_str | const char ** + * #mkldnn_query_op_d | const_mkldnn_op_desc_t * + * *_md | const mkldnn_memory_desc_t ** + * *_${op}_d | const mkldnn_${op}_desc_t ** + * *_pd | const_mkldnn_primitive_desc_t * + * + * @note + * Rule of thumb: all opaque types and structures are returned by + * reference. All numbers are returned by value. + * + * @warning + * All returned references point to constant objects and are valid only + * during the lifetime of the queried primitive descriptor. Returned objects + * must not be destroyed by the user. If you need to keep the object longer + * than the lifetime of the queried primitive descriptor, use + * mkldnn_primitive_desc_clone() to make a copy. */ +typedef enum { + mkldnn_query_undef = 0, /**< no query */ + + mkldnn_query_engine, /**< execution engine */ + mkldnn_query_primitive_kind, /**< primitive kind */ + + mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */ + mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */ + + mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */ + mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra + (scratch) memory, additional to all + inputs and outputs memory (bytes) */ + + mkldnn_query_scratchpad_engine, /**< scratchpad engine -- engine to be used + for creating scratchpad memory */ + + mkldnn_query_impl_info_str, /**< implementation name */ + + /* memory and op descriptor section */ + mkldnn_query_some_d = 64, /**< stub */ + mkldnn_query_op_d, /**< op descriptor */ + mkldnn_query_convolution_d, /**< convolution descriptor */ + mkldnn_query_deconvolution_d, /**< deconvolution descriptor */ + mkldnn_query_shuffle_d, /**< shuffle descriptor */ + mkldnn_query_eltwise_d, /**< eltwise descriptor */ + mkldnn_query_softmax_d, /**< softmax descriptor */ + mkldnn_query_pooling_d, /**< pooling descriptor */ + mkldnn_query_lrn_d, /**< lrn descriptor */ + mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */ + mkldnn_query_inner_product_d, /**< inner product descriptor */ + mkldnn_query_rnn_d, /**< rnn descriptor */ + + /* memory descriptor section */ + mkldnn_query_some_md = 128, /**< stub */ + mkldnn_query_src_md, /**< source memory desc */ + mkldnn_query_diff_src_md, /**< source gradient memory desc */ + mkldnn_query_weights_md, /**< weights memory descriptor desc */ + mkldnn_query_diff_weights_md, /**< weights grad. memory desc */ + mkldnn_query_dst_md, /**< destination memory desc */ + mkldnn_query_diff_dst_md, /**< destination grad. memory desc */ + mkldnn_query_workspace_md, /**< workspace memory desc */ + mkldnn_query_scratchpad_md, /**< scratchpad memory desc */ +} mkldnn_query_t; + +/** @} */ + +/** @addtogroup c_api_types_stream Execution stream + * @{ */ + +/** @brief Stream flags. */ +typedef enum { + /** A default stream configuration. */ + mkldnn_stream_default_flags = 0x0U, +} mkldnn_stream_flags_t; + +/** @struct mkldnn_stream + * An opaque structure to describe an execution stream. */ +struct mkldnn_stream; +/** An execution stream handle. */ +typedef struct mkldnn_stream *mkldnn_stream_t; +/** A constant execution stream handle. */ +typedef const struct mkldnn_stream *const_mkldnn_stream_t; + +/** @} */ +/** @} */ +/** @} */ + +#ifdef __cplusplus +} +#endif + + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h new file mode 100644 index 0000000000..a2713deccb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_VERSION_H +#define MKLDNN_VERSION_H + +/* Major version of MKL-DNN */ +#define MKLDNN_VERSION_MAJOR 0 + +/* Minor version of MKL-DNN */ +#define MKLDNN_VERSION_MINOR 90 + +/* Patch version of MKL-DNN */ +#define MKLDNN_VERSION_PATCH 0 + +/* Git Commit Hash of MKL-DNN */ +#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c" + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in new file mode 100644 index 0000000000..5ee0126188 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_VERSION_H +#define MKLDNN_VERSION_H + +/* Major version of MKL-DNN */ +#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@ + +/* Minor version of MKL-DNN */ +#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@ + +/* Patch version of MKL-DNN */ +#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@ + +/* Git Commit Hash of MKL-DNN */ +#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@" + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp new file mode 100644 index 0000000000..1a51d8562b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) { + bool args_ok = true + && !any_null(bnrm_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data, backward) + && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto bd = batch_normalization_desc_t(); + bd.primitive_kind = primitive_kind::batch_normalization; + bd.prop_kind = prop_kind; + + bd.data_desc = *data_desc; + bd.diff_data_desc = zero_md(); + if ( one_of(bd.prop_kind,backward_data, backward) ) + bd.diff_data_desc = *diff_data_desc; + + dims_t scaleshift_dims = { 2, data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2, + scaleshift_dims, data_type::f32, mkldnn_nc); + bd.diff_data_scaleshift_desc = zero_md(); + if (bd.prop_kind == backward) { + bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc; + } + + dims_t stats_dims = { data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims, + data_type::f32, mkldnn_x); + bd.variance_desc = bd.mean_desc; + bd.batch_norm_epsilon = epsilon; + + unsigned bnorm_flags = + mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu; + if ((~bnorm_flags & flags) != 0) return invalid_arguments; + + bd.flags = flags; + + bool consistency = true + && utils::one_of(bd.data_desc.ndims, 2, 4, 5); + if (bd.prop_kind == backward_data) + consistency = consistency + && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5) + && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims, + bd.diff_data_desc.ndims); + if (!consistency) return invalid_arguments; + + *bnrm_desc = bd; + return success; +} +} + +status_t mkldnn_batch_normalization_forward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, float epsilon, unsigned flags) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr, + epsilon, flags); +} + +status_t mkldnn_batch_normalization_backward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc, + float epsilon, unsigned flags) { + if (!one_of(prop_kind, backward, backward_data)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc, + epsilon, flags); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp new file mode 100644 index 0000000000..f61410b33c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp @@ -0,0 +1,240 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef BATCH_NORMALIZATION_PD_HPP +#define BATCH_NORMALIZATION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct batch_normalization_fwd_pd_t; + +struct batch_normalization_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::batch_normalization; + + batch_normalization_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , stat_md_(desc_.mean_desc) + , scaleshift_md_(desc_.data_scaleshift_desc) + , ws_md_() + {} + + const batch_normalization_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::batch_normalization_d: + *(const batch_normalization_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common batch_normalization aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return desc_.data_desc.ndims; } + + bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; } + bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; } + bool use_global_stats() const + { return desc_.flags & mkldnn_use_global_stats; } + bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; } + bool with_relu_post_op() const { + const auto &p = this->attr()->post_ops_; + return p.len_ == 1 && p.entry_[0].is_relu(true, true); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + bool is_bwd() const { return !this->is_fwd(); } + bool is_training() const + { return desc_.prop_kind == prop_kind::forward_training; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + batch_normalization_desc_t desc_; + const batch_normalization_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t stat_md_; + memory_desc_t scaleshift_md_; + + memory_desc_t ws_md_; + + void init_default_ws(size_t bits_per_element) { + const auto data_mdw = memory_desc_wrapper(data_md_); + + const dim_t data_nelems = data_mdw.nelems(true); + const dim_t bits_per_byte = 8; + const dims_t ws_sz = { (dim_t)utils::div_up( + data_nelems * bits_per_element, bits_per_byte) }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8, + format_tag::x); + } + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_fwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_fwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input; + if (arg == MKLDNN_ARG_DST) return arg_usage_t::output; + + if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) { + if (stats_is_src()) return arg_usage_t::input; + if (!stats_is_src() && is_training()) return arg_usage_t::output; + return arg_usage_t::unused; + } + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (!stats_is_src() && is_training() && (index == 1 || index == 2)) + return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const + { return stats_is_src() ? src_md(1) : dst_md(1); } + + virtual int n_inputs() const override + { return 1 + 2 * stats_is_src() + use_scaleshift(); } + virtual int n_outputs() const override + { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); } +}; + +struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_bwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_bwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN, + MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override + { return index == 0 ? &diff_scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const { return src_md(1); } + + virtual int n_inputs() const override + { return 4 + use_scaleshift() + fuse_bn_relu(); } + virtual int n_outputs() const override + { return 1 + (desc_.prop_kind == prop_kind::backward); } + +protected: + memory_desc_t diff_data_md_; + memory_desc_t diff_scaleshift_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp new file mode 100644 index 0000000000..3d43a0fbee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp @@ -0,0 +1,550 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TYPE_MAPPING_HPP +#define TYPE_MAPPING_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { + +// TODO: autogenerate this + +using dim_t = mkldnn_dim_t; +using dims_t = mkldnn_dims_t; +using stride_t = mkldnn_dim_t; +using strides_t = mkldnn_strides_t; + +using status_t = mkldnn_status_t; +namespace status { + const status_t success = mkldnn_success; + const status_t out_of_memory = mkldnn_out_of_memory; + const status_t try_again = mkldnn_try_again; + const status_t invalid_arguments = mkldnn_invalid_arguments; + const status_t not_ready = mkldnn_not_ready; + const status_t unimplemented = mkldnn_unimplemented; + const status_t iterator_ends = mkldnn_iterator_ends; + const status_t runtime_error = mkldnn_runtime_error; + const status_t not_required = mkldnn_not_required; +} + +using prop_kind_t = mkldnn_prop_kind_t; +namespace prop_kind { + const prop_kind_t undef = mkldnn_prop_kind_undef; + const prop_kind_t forward_training = mkldnn_forward_training; + const prop_kind_t forward_inference = mkldnn_forward_inference; + const prop_kind_t forward_scoring = mkldnn_forward_scoring; + const prop_kind_t forward = mkldnn_forward; + const prop_kind_t backward = mkldnn_backward; + const prop_kind_t backward_data = mkldnn_backward_data; + const prop_kind_t backward_weights = mkldnn_backward_weights; + const prop_kind_t backward_bias = mkldnn_backward_bias; +} + +using alg_kind_t = mkldnn_alg_kind_t; +namespace alg_kind { + const alg_kind_t undef = mkldnn_alg_kind_undef; + const alg_kind_t convolution_auto = mkldnn_convolution_auto; + const alg_kind_t convolution_direct = mkldnn_convolution_direct; + const alg_kind_t convolution_winograd = mkldnn_convolution_winograd; + const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct; + const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd; + const alg_kind_t eltwise_relu = mkldnn_eltwise_relu; + const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh; + const alg_kind_t eltwise_elu = mkldnn_eltwise_elu; + const alg_kind_t eltwise_square = mkldnn_eltwise_square; + const alg_kind_t eltwise_abs = mkldnn_eltwise_abs; + const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt; + const alg_kind_t eltwise_linear = mkldnn_eltwise_linear; + const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu; + const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu; + const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic; + const alg_kind_t pooling_max = mkldnn_pooling_max; + const alg_kind_t pooling_avg = mkldnn_pooling_avg; + const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding; + const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding; + const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels; + const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel; + const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn; + const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm; + const alg_kind_t vanilla_gru = mkldnn_vanilla_gru; + const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset; +} + +using data_type_t = mkldnn_data_type_t; +namespace data_type { + const data_type_t undef = mkldnn_data_type_undef; + const data_type_t f32 = mkldnn_f32; + const data_type_t s32 = mkldnn_s32; + const data_type_t s8 = mkldnn_s8; + const data_type_t u8 = mkldnn_u8; +} + +using scratchpad_mode_t = mkldnn_scratchpad_mode_t; +namespace scratchpad_mode { + const scratchpad_mode_t library = mkldnn_scratchpad_mode_library; + const scratchpad_mode_t user = mkldnn_scratchpad_mode_user; +} + +using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t; +namespace rnn_packed_format { + const rnn_packed_format_t undef = mkldnn_packed_format_undef; + const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p; + const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p; +} + +using format_kind_t = mkldnn_format_kind_t; +namespace format_kind { + const format_kind_t undef = mkldnn_format_kind_undef; + const format_kind_t any = mkldnn_format_kind_any; + const format_kind_t blocked = mkldnn_blocked; + const format_kind_t wino = mkldnn_format_kind_wino; + const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed; +} + +using format_tag_t = mkldnn_format_tag_t; +namespace format_tag { + const format_tag_t undef = mkldnn_format_tag_undef; + const format_tag_t any = mkldnn_format_tag_any; + const format_tag_t a = mkldnn_a; + const format_tag_t ab = mkldnn_ab; + const format_tag_t abc = mkldnn_abc; + const format_tag_t abcd = mkldnn_abcd; + const format_tag_t abcde = mkldnn_abcde; + const format_tag_t abcdef = mkldnn_abcdef; + const format_tag_t abdec = mkldnn_abdec; + const format_tag_t acb = mkldnn_acb; + const format_tag_t acbde = mkldnn_acbde; + const format_tag_t acdb = mkldnn_acdb; + const format_tag_t acdeb = mkldnn_acdeb; + const format_tag_t ba = mkldnn_ba; + const format_tag_t bac = mkldnn_bac; + const format_tag_t bacd = mkldnn_bacd; + const format_tag_t bcda = mkldnn_bcda; + const format_tag_t cba = mkldnn_cba; + const format_tag_t cdba = mkldnn_cdba; + const format_tag_t cdeba = mkldnn_cdeba; + const format_tag_t decab = mkldnn_decab; + const format_tag_t Abc16a = mkldnn_Abc16a; + const format_tag_t ABc16a16b = mkldnn_ABc16a16b; + const format_tag_t aBc16b = mkldnn_aBc16b; + const format_tag_t ABc16b16a = mkldnn_ABc16b16a; + const format_tag_t Abc4a = mkldnn_Abc4a; + const format_tag_t aBc4b = mkldnn_aBc4b; + const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b; + const format_tag_t ABc4b4a = mkldnn_ABc4b4a; + const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a; + const format_tag_t ABc8a8b = mkldnn_ABc8a8b; + const format_tag_t aBc8b = mkldnn_aBc8b; + const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b; + const format_tag_t ABc8b8a = mkldnn_ABc8b8a; + const format_tag_t Abcd16a = mkldnn_Abcd16a; + const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b; + const format_tag_t aBcd16b = mkldnn_aBcd16b; + const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a; + const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c; + const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b; + const format_tag_t Abcd4a = mkldnn_Abcd4a; + const format_tag_t aBcd4b = mkldnn_aBcd4b; + const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b; + const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a; + const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c; + const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b; + const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a; + const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b; + const format_tag_t aBcd8b = mkldnn_aBcd8b; + const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b; + const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b; + const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a; + const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c; + const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c; + const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b; + const format_tag_t Abcde16a = mkldnn_Abcde16a; + const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b; + const format_tag_t aBcde16b = mkldnn_aBcde16b; + const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a; + const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c; + const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b; + const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c; + const format_tag_t Abcde4a = mkldnn_Abcde4a; + const format_tag_t aBcde4b = mkldnn_aBcde4b; + const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a; + const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c; + const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c; + const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b; + const format_tag_t Abcde8a = mkldnn_Abcde8a; + const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b; + const format_tag_t aBcde8b = mkldnn_aBcde8b; + const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b; + const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b; + const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a; + const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c; + const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c; + const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b; + const format_tag_t aBcdef16b = mkldnn_aBcdef16b; + const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c; + const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b; + const format_tag_t aBcdef4b = mkldnn_aBcdef4b; + const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b; + const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c; + const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c; + const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b; + const format_tag_t aBdc16b = mkldnn_aBdc16b; + const format_tag_t aBdc4b = mkldnn_aBdc4b; + const format_tag_t aBdc8b = mkldnn_aBdc8b; + const format_tag_t aBdec16b = mkldnn_aBdec16b; + const format_tag_t aBdec4b = mkldnn_aBdec4b; + const format_tag_t aBdec8b = mkldnn_aBdec8b; + const format_tag_t aBdefc16b = mkldnn_aBdefc16b; + const format_tag_t aBdefc4b = mkldnn_aBdefc4b; + const format_tag_t aBdefc8b = mkldnn_aBdefc8b; + const format_tag_t Acb16a = mkldnn_Acb16a; + const format_tag_t Acb4a = mkldnn_Acb4a; + const format_tag_t Acb8a = mkldnn_Acb8a; + const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c; + const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c; + const format_tag_t Acdb16a = mkldnn_Acdb16a; + const format_tag_t Acdb4a = mkldnn_Acdb4a; + const format_tag_t Acdb8a = mkldnn_Acdb8a; + const format_tag_t Acdeb16a = mkldnn_Acdeb16a; + const format_tag_t Acdeb4a = mkldnn_Acdeb4a; + const format_tag_t Acdeb8a = mkldnn_Acdeb8a; + const format_tag_t BAc16a16b = mkldnn_BAc16a16b; + const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b; + const format_tag_t last = mkldnn_format_tag_last; + + const format_tag_t x = mkldnn_x; + const format_tag_t nc = mkldnn_nc; + const format_tag_t cn = mkldnn_cn; + const format_tag_t ncw = mkldnn_ncw; + const format_tag_t nwc = mkldnn_nwc; + const format_tag_t nchw = mkldnn_nchw; + const format_tag_t nhwc = mkldnn_nhwc; + const format_tag_t chwn = mkldnn_chwn; + const format_tag_t ncdhw = mkldnn_ncdhw; + const format_tag_t ndhwc = mkldnn_ndhwc; + const format_tag_t oi = mkldnn_oi; + const format_tag_t io = mkldnn_io; + const format_tag_t oiw = mkldnn_oiw; + const format_tag_t wio = mkldnn_wio; + const format_tag_t oihw = mkldnn_oihw; + const format_tag_t hwio = mkldnn_hwio; + const format_tag_t ihwo = mkldnn_ihwo; + const format_tag_t iohw = mkldnn_iohw; + const format_tag_t oidhw = mkldnn_oidhw; + const format_tag_t dhwio = mkldnn_dhwio; + const format_tag_t goiw = mkldnn_goiw; + const format_tag_t goihw = mkldnn_goihw; + const format_tag_t hwigo = mkldnn_hwigo; + const format_tag_t giohw = mkldnn_giohw; + const format_tag_t goidhw = mkldnn_goidhw; + const format_tag_t tnc = mkldnn_tnc; + const format_tag_t ntc = mkldnn_ntc; + const format_tag_t ldsnc = mkldnn_ldsnc; + const format_tag_t ldigo = mkldnn_ldigo; + const format_tag_t ldgoi = mkldnn_ldgoi; + const format_tag_t ldgo = mkldnn_ldgo; + const format_tag_t nCdhw16c = mkldnn_nCdhw16c; + const format_tag_t nCdhw4c = mkldnn_nCdhw4c; + const format_tag_t nCdhw8c = mkldnn_nCdhw8c; + const format_tag_t nChw16c = mkldnn_nChw16c; + const format_tag_t nChw4c = mkldnn_nChw4c; + const format_tag_t nChw8c = mkldnn_nChw8c; + const format_tag_t nCw16c = mkldnn_nCw16c; + const format_tag_t nCw4c = mkldnn_nCw4c; + const format_tag_t nCw8c = mkldnn_nCw8c; + const format_tag_t IOw16o16i = mkldnn_IOw16o16i; + const format_tag_t OIw16i16o = mkldnn_OIw16i16o; + const format_tag_t OIw16o16i = mkldnn_OIw16o16i; + const format_tag_t Oiw16o = mkldnn_Oiw16o; + const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i; + const format_tag_t OIw4i4o = mkldnn_OIw4i4o; + const format_tag_t Oiw4o = mkldnn_Oiw4o; + const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i; + const format_tag_t OIw8i8o = mkldnn_OIw8i8o; + const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o; + const format_tag_t OIw8o8i = mkldnn_OIw8o8i; + const format_tag_t Owi16o = mkldnn_Owi16o; + const format_tag_t Owi4o = mkldnn_Owi4o; + const format_tag_t Owi8o = mkldnn_Owi8o; + const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i; + const format_tag_t Ohwi16o = mkldnn_Ohwi16o; + const format_tag_t Ohwi4o = mkldnn_Ohwi4o; + const format_tag_t Ohwi8o = mkldnn_Ohwi8o; + const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o; + const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i; + const format_tag_t Oihw16o = mkldnn_Oihw16o; + const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i; + const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o; + const format_tag_t Oihw4o = mkldnn_Oihw4o; + const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i; + const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o; + const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o; + const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i; + const format_tag_t Odhwi16o = mkldnn_Odhwi16o; + const format_tag_t Odhwi4o = mkldnn_Odhwi4o; + const format_tag_t Odhwi8o = mkldnn_Odhwi8o; + const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o; + const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i; + const format_tag_t Oidhw16o = mkldnn_Oidhw16o; + const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o; + const format_tag_t Oidhw4o = mkldnn_Oidhw4o; + const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i; + const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o; + const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i; + const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i; + const format_tag_t Goiw16g = mkldnn_Goiw16g; + const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o; + const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i; + const format_tag_t gOiw16o = mkldnn_gOiw16o; + const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i; + const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o; + const format_tag_t gOiw4o = mkldnn_gOiw4o; + const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i; + const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o; + const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o; + const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i; + const format_tag_t gOwi16o = mkldnn_gOwi16o; + const format_tag_t gOwi4o = mkldnn_gOwi4o; + const format_tag_t gOwi8o = mkldnn_gOwi8o; + const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i; + const format_tag_t gOhwi16o = mkldnn_gOhwi16o; + const format_tag_t gOhwi4o = mkldnn_gOhwi4o; + const format_tag_t gOhwi8o = mkldnn_gOhwi8o; + const format_tag_t Goihw16g = mkldnn_Goihw16g; + const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o; + const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i; + const format_tag_t gOihw16o = mkldnn_gOihw16o; + const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i; + const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i; + const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o; + const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i; + const format_tag_t gOihw4o = mkldnn_gOihw4o; + const format_tag_t Goihw8g = mkldnn_Goihw8g; + const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i; + const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o; + const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o; + const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i; + const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o; + const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o; + const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o; + const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o; + const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i; + const format_tag_t gOidhw16o = mkldnn_gOidhw16o; + const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o; + const format_tag_t gOidhw4o = mkldnn_gOidhw4o; + const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i; + const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o; + const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i; +} + +using memory_extra_flags_t = mkldnn_memory_extra_flags_t; +namespace memory_extra_flags { + const memory_extra_flags_t none = mkldnn_memory_extra_flag_none; + const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8; + const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust; +} + +using padding_kind_t = mkldnn_padding_kind_t; +namespace padding_kind { + const padding_kind_t padding_zero = mkldnn_padding_zero; +} + +using engine_kind_t = mkldnn_engine_kind_t; +namespace engine_kind { + const engine_kind_t any_engine = mkldnn_any_engine; + const engine_kind_t cpu = mkldnn_cpu; +} + +using primitive_kind_t = mkldnn_primitive_kind_t; +namespace primitive_kind { + const primitive_kind_t undefined = mkldnn_undefined_primitive; + const primitive_kind_t reorder = mkldnn_reorder; + const primitive_kind_t concat = mkldnn_concat; + const primitive_kind_t sum = mkldnn_sum; + const primitive_kind_t convolution = mkldnn_convolution; + const primitive_kind_t deconvolution = mkldnn_deconvolution; + const primitive_kind_t shuffle = mkldnn_shuffle; + const primitive_kind_t eltwise = mkldnn_eltwise; + const primitive_kind_t softmax = mkldnn_softmax; + const primitive_kind_t pooling = mkldnn_pooling; + const primitive_kind_t lrn = mkldnn_lrn; + const primitive_kind_t batch_normalization = mkldnn_batch_normalization; + const primitive_kind_t inner_product = mkldnn_inner_product; + const primitive_kind_t rnn = mkldnn_rnn; +} + +using query_t = mkldnn_query_t; +namespace query { + const query_t undef = mkldnn_query_undef; + + const query_t engine = mkldnn_query_engine; + const query_t primitive_kind = mkldnn_query_primitive_kind; + + const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32; + const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32; + + const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64; + const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64; + + const query_t scratchpad_engine = mkldnn_query_scratchpad_engine; + + const query_t impl_info_str = mkldnn_query_impl_info_str; + + const query_t some_d = mkldnn_query_some_d; + const query_t op_d = mkldnn_query_op_d; + const query_t convolution_d = mkldnn_query_convolution_d; + const query_t deconvolution_d = mkldnn_query_deconvolution_d; + const query_t shuffle_d = mkldnn_query_shuffle_d; + const query_t eltwise_d = mkldnn_query_eltwise_d; + const query_t softmax_d = mkldnn_query_softmax_d; + const query_t pooling_d = mkldnn_query_pooling_d; + const query_t lrn_d = mkldnn_query_lrn_d; + const query_t batch_normalization_d = mkldnn_query_batch_normalization_d; + const query_t inner_product_d = mkldnn_query_inner_product_d; + const query_t rnn_d = mkldnn_query_rnn_d; + + const query_t some_md = mkldnn_query_some_md; + const query_t src_md = mkldnn_query_src_md; + const query_t diff_src_md = mkldnn_query_diff_src_md; + const query_t weights_md = mkldnn_query_weights_md; + const query_t diff_weights_md = mkldnn_query_diff_weights_md; + const query_t dst_md = mkldnn_query_dst_md; + const query_t diff_dst_md = mkldnn_query_diff_dst_md; + + const query_t workspace_md = mkldnn_query_workspace_md; + const query_t scratchpad_md = mkldnn_query_scratchpad_md; +} + +using blocking_desc_t = mkldnn_blocking_desc_t; +using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t; +using wino_desc_t = mkldnn_wino_desc_t; +using memory_extra_desc_t = mkldnn_memory_extra_desc_t; +using memory_desc_t = mkldnn_memory_desc_t; +using convolution_desc_t = mkldnn_convolution_desc_t; +using deconvolution_desc_t = mkldnn_deconvolution_desc_t; +using shuffle_desc_t = mkldnn_shuffle_desc_t; +using pooling_desc_t = mkldnn_pooling_desc_t; +using eltwise_desc_t = mkldnn_eltwise_desc_t; +using softmax_desc_t = mkldnn_softmax_desc_t; +using lrn_desc_t = mkldnn_lrn_desc_t; +using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t; +using inner_product_desc_t = mkldnn_inner_product_desc_t; + +using rnn_direction_t = mkldnn_rnn_direction_t; +using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t; +using rnn_desc_t = mkldnn_rnn_desc_t; + +/* C op_desc_t, which eventually are just (void*) */ +using c_op_desc_t = mkldnn_op_desc_t; +using const_c_op_desc_t = const_mkldnn_op_desc_t; + +struct op_desc_t { + union { + primitive_kind_t kind; + convolution_desc_t convolution; + deconvolution_desc_t deconvolution; + shuffle_desc_t shuffle; + pooling_desc_t pooling; + eltwise_desc_t eltwise; + softmax_desc_t softmax; + lrn_desc_t lrn; + batch_normalization_desc_t batch_normalization; + inner_product_desc_t inner_product; + rnn_desc_t rnn; + }; + + op_desc_t(const primitive_kind_t &_): kind(_) {} + +# define DECL_CTOR_AND_CONVERTERS(c_type, name) \ + op_desc_t(const c_type &_): name(_) {} \ + static op_desc_t *convert_from_c(c_type *_) \ + { return reinterpret_cast(_); } \ + static const op_desc_t *convert_from_c(const c_type *_) \ + { return reinterpret_cast(_); } + + DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution); + DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle); + DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling); + DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise); + DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax); + DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn); + DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization); + DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product); + DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn); + +# undef DECL_CTOR_AND_CONVERTERS +}; + +using engine_t = mkldnn_engine; +using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator; +using primitive_desc_t = mkldnn_primitive_desc; +using primitive_attr_t = mkldnn_primitive_attr; +using post_ops_t = mkldnn_post_ops; +using memory_t = mkldnn_memory; +using primitive_t = mkldnn_primitive; + +using primitive_arg_index_t = int; + +using stream_flags_t = mkldnn_stream_flags_t; +namespace stream_flags { + const stream_flags_t default_flags = mkldnn_stream_default_flags; +} +using stream_t = mkldnn_stream; + +/* forward declaration of the internal primitive_desc types */ +struct batch_normalization_bwd_pd_t; +struct batch_normalization_fwd_pd_t; +struct batch_normalization_pd_t; +struct concat_pd_t; +struct convolution_bwd_data_pd_t; +struct convolution_bwd_weights_pd_t; +struct convolution_fwd_pd_t; +struct convolution_pd_t; +struct deconvolution_bwd_data_pd_t; +struct deconvolution_bwd_weights_pd_t; +struct deconvolution_fwd_pd_t; +struct deconvolution_pd_t; +struct eltwise_bwd_pd_t; +struct eltwise_fwd_pd_t; +struct eltwise_pd_t; +struct inner_product_bwd_data_pd_t; +struct inner_product_bwd_weights_pd_t; +struct inner_product_fwd_pd_t; +struct inner_product_pd_t; +struct lrn_bwd_pd_t; +struct lrn_fwd_pd_t; +struct lrn_pd_t; +struct pooling_bwd_pd_t; +struct pooling_fwd_pd_t; +struct pooling_pd_t; +struct reorder_pd_t; +struct rnn_bwd_pd_t; +struct rnn_fwd_pd_t; +struct rnn_pd_t; +struct shuffle_pd_t; +struct softmax_bwd_pd_t; +struct softmax_fwd_pd_t; +struct softmax_pd_t; +struct sum_pd_t; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp new file mode 100644 index 0000000000..ed4c35c6e9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "concat_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds, + const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(concat_pd, src_mds) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + int concat_dim_sz = dims[concat_dim]; + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (d == concat_dim) continue; + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + concat_dim_sz += src_mds[i].dims[concat_dim]; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != + (d == concat_dim ? concat_dim_sz : dims[d])) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.dims[concat_dim] = concat_dim_sz; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto c_pd = reinterpret_cast(concat_pd); + + for (auto c = engine->get_concat_implementation_list(); *c; ++c) { + if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds) + == success) { + (*c_pd)->init_info(); + (*c_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp new file mode 100644 index 0000000000..29311927e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp @@ -0,0 +1,211 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONCAT_PD_HPP +#define CONCAT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct concat_pd_t: public primitive_desc_t { + concat_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::concat) + , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md) + { + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + concat_pd_t(const concat_pd_t &rhs) = default; + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + int concat_dim() const { return concat_dim_; } + + const memory_desc_t *src_image_md(int index = 0) const + { return index < n_inputs() ? &src_image_mds_[index] : nullptr; } + +protected: + int n_, concat_dim_; + memory_desc_t dst_md_; + nstl::vector src_mds_; + + /* contains images of srcs in the dst memory (if possible) + * Lives here to simplify some implementations. An implementation might + * use this auxiliary array iff init() returned success */ + nstl::vector src_image_mds_; + +protected: + /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */ + status_t init() { + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + if (!i_d.is_blocking_desc() || i_d.is_additional_buffer()) + return status::unimplemented; + } + + const int ndims = dst_md_.ndims; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) return status; + src_image_mds_.push_back(src_img_d); + current_concat_dim_offset += dim; + } + + return status::success; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + const int ndims = dst_md_.ndims; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain or it is not possible to create a + * blocked format for the output, pick the format of the plain input + * - If this fails as well, use plain layout (abcd...) + */ + status_t status = status::unimplemented; + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && !src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + if (status == status::success) break; + } + } + + if (status == status::success) { + /* check if we can create a sub-memory for the dst */ + bool desired_format_ok = true; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) { + desired_format_ok = false; + break; + } + current_concat_dim_offset += dim; + } + + if (!desired_format_ok) + status = status::unimplemented; + } + + /* if no success so far, try using the format of the first plain input */ + if (status != status::success) { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + memory_desc_wrapper(src_mds_[0]).blocking_desc()); + if (status == status::success) return status; + } + } + } + + /* the last line of defense: use plain abcd... format */ + if (status != status::success) + status = memory_desc_init_by_strides(dst_md_, nullptr); + + return status; + } +}; + +#define DECLARE_CONCAT_PD_t(impl_name, ...) \ + static status_t create(concat_pd_t **concat_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, int concat_dim, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign(*concat_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_CONCAT_PD_T(impl_name, ...) \ + DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp new file mode 100644 index 0000000000..0c5c02bcd1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto cd = convolution_desc_t(); + cd.primitive_kind = primitive_kind::convolution; + cd.prop_kind = prop_kind; + cd.alg_kind = alg_kind; + + cd.diff_src_desc = cd.src_desc = zero_md(); + cd.diff_dst_desc = cd.dst_desc = zero_md(); + cd.diff_weights_desc = cd.weights_desc = zero_md(); + cd.diff_bias_desc = cd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc; + (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) = + *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(cd.strides, strides, sp_dims); + utils::array_copy(cd.padding[0], padding_l, sp_dims); + utils::array_copy(cd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(cd.dilates, dilates, sp_dims); + else + utils::array_set(cd.dilates, 0, sp_dims); + + cd.padding_kind = padding_kind; + cd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + const int bias_dim = prop_kind == backward_data + ? src_desc->dims[1] + : dst_desc->dims[1]; + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == bias_dim : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) + { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = cd.dilates[i - 2]; + int pad_l = padding_l[i - 2]; + int pad_r = padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + if (str < 1) return invalid_arguments; + consistency = consistency + && dil >= 0 + && pad_l >= 0 + && pad_r + str > 0 + && (src - ker_range + pad_l + pad_r) / str + 1 == dst; + } + if (!consistency) return invalid_arguments; + + *conv_desc = cd; + return success; +} +} +} + +status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_forward_desc_init( + convolution_desc_t *conv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + nullptr, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + dilates, padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp new file mode 100644 index 0000000000..9604e0acf5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "convolution_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? &desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_src_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_wei_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_bia_d(const_cast(desc)); } +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_dst_d(const_cast(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp new file mode 100644 index 0000000000..b10c36db49 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONVOLUTION_PD_HPP +#define CONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind); + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc); + +struct convolution_fwd_pd_t; + +struct convolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::convolution; + + convolution_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const convolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits::query_d: + *(const convolution_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common conv aux functions */ + + dim_t MB() const { return _src_md()->dims[0]; } + + dim_t IC() const { return _src_md()->dims[1]; } + dim_t OC() const { return _dst_md()->dims[1]; } + dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; } + + dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; } + dim_t IW() const { return _src_md()->dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; } + dim_t OW() const { return _dst_md()->dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; } + dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; } + dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return _src_md()->ndims; } + + bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); } + bool with_groups() const { return _wei_md()->ndims == ndims() + 1; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*_src_md()); + const auto d_d = memory_desc_wrapper(*_dst_md()); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + convolution_desc_t desc_; + const convolution_fwd_pd_t *hint_fwd_pd_; + + bool set_default_formats_common_template( + memory_desc_t &src_md, format_tag_t src_tag, + memory_desc_t &wei_md, format_tag_t wei_tag, + memory_desc_t &dst_md, format_tag_t dst_tag, + memory_desc_t &bia_md) { + using namespace format_tag; + +# define IS_OK(f) \ + do { if ((f) != status::success) return false; } while(0) + if (src_md.format_kind == format_kind::any + && !utils::one_of(src_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(src_md, src_tag)); + if (dst_md.format_kind == format_kind::any + && !utils::one_of(dst_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); + if (wei_md.format_kind == format_kind::any + && !utils::one_of(wei_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); + if (with_bias() && bia_md.format_kind == format_kind::any) + IS_OK(memory_desc_init_by_tag(bia_md, x)); +# undef IS_OK + + return true; + } + + bool set_default_alg_kind(alg_kind_t alg_kind) { + assert(utils::one_of(alg_kind, alg_kind::convolution_direct, + alg_kind::convolution_winograd)); + if (desc_.alg_kind == alg_kind::convolution_auto) + desc_.alg_kind = alg_kind; + return desc_.alg_kind == alg_kind; + } + + bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, + data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { + bool ok = true + && (src_dt == data_type::undef || _src_md()->data_type == src_dt) + && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt) + && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt) + && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt); + if (with_bias() && bia_dt != data_type::undef) + ok = ok && _bia_md()->data_type == bia_dt; + return ok; + } + +private: + const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); } + const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); } + const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); } + const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); } +}; + +struct convolution_fwd_pd_t: public convolution_pd_t { + typedef convolution_fwd_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_fwd_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t wei_tag, format_tag_t dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); + } +}; + +struct convolution_bwd_data_pd_t: public convolution_pd_t { + typedef convolution_bwd_data_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_data_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + + virtual bool support_bias() const { return false; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t diff_src_tag, + format_tag_t wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(diff_src_md_, diff_src_tag, + weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_); + } +}; + +struct convolution_bwd_weights_pd_t: public convolution_pd_t { + typedef convolution_bwd_weights_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_weights_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, + diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp new file mode 100644 index 0000000000..98063c1c37 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp @@ -0,0 +1,188 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t deconv_desc_init(deconvolution_desc_t *deconv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, deconvolution_direct, deconvolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) + return invalid_arguments; + + if (padding_r == nullptr) + padding_r = padding_l; + + auto dd = deconvolution_desc_t(); + dd.primitive_kind = primitive_kind::deconvolution; + dd.prop_kind = prop_kind; + dd.alg_kind = alg_kind; + + dd.diff_src_desc = dd.src_desc = zero_md(); + dd.diff_dst_desc = dd.dst_desc = zero_md(); + dd.diff_weights_desc = dd.weights_desc = zero_md(); + dd.diff_bias_desc = dd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias + = bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc; + (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc) + = *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc) + = *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(dd.strides, strides, sp_dims); + utils::array_copy(dd.padding[0], padding_l, sp_dims); + utils::array_copy(dd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(dd.dilates, dilates, sp_dims); + else + utils::array_set(dd.dilates, 0, sp_dims); + + dd.padding_kind = padding_kind; + dd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + bool consistency = true + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = dd.dilates[i - 2]; + int pad = padding_l[i - 2] + padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + consistency + = consistency && (dst - ker_range + pad) / str + 1 == src; + } + if (!consistency) + return invalid_arguments; + + *deconv_desc = dd; + return success; +} +} + +status_t mkldnn_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp new file mode 100644 index 0000000000..539e44bd9b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp @@ -0,0 +1,293 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DECONVOLUTION_PD_HPP +#define DECONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct deconvolution_fwd_pd_t; + +struct deconvolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::deconvolution; + + deconvolution_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const deconvolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits::query_d: + *(const deconvolution_desc_t **)result = desc(); + break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */ + + dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; } + + dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; } + dim_t G() const + { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; } + + dim_t ID() const { + return ndims() >= 5 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1]; + } + + dim_t OD() const { + return ndims() >= 5 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1]; + } + + dim_t KD() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 5 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1; + } + dim_t KH() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 4 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1; + } + dim_t KW() const { + const int w_ndims = ndims() + with_groups(); + return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1]; + } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + bool with_bias() const { + return + !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero(); + } + + bool with_groups() const + { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; } + + int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + deconvolution_desc_t desc_; + const deconvolution_fwd_pd_t *hint_fwd_pd_; +}; + +struct deconvolution_fwd_pd_t: public deconvolution_pd_t { + typedef deconvolution_fwd_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_fwd_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; +}; + +struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_data_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_data_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; +}; + +struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_weights_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_weights_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp new file mode 100644 index 0000000000..f1708fca52 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float alpha, float beta) { + bool args_ok = true + && !any_null(eltwise_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data) + && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ed = eltwise_desc_t(); + ed.primitive_kind = primitive_kind::eltwise; + ed.prop_kind = prop_kind; + ed.alg_kind = alg_kind; + + ed.data_desc = *data_desc; + ed.diff_data_desc = + (ed.prop_kind == backward_data) ? *diff_data_desc : zero_md(); + + ed.alpha = alpha; + ed.beta = beta; + + bool consistency = true + && IMPLICATION(ed.prop_kind == backward_data, + array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims, + ed.diff_data_desc.ndims)); + if (!consistency) return invalid_arguments; + + *eltwise_desc = ed; + return success; +} +} + +status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, float alpha, float beta) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc, + nullptr, alpha, beta); +} + +status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_data_desc, + const memory_desc_t *data_desc, float alpha, float beta) { + return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc, + diff_data_desc, alpha, beta); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp new file mode 100644 index 0000000000..9fd260fcee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ELTWISE_PD_HPP +#define ELTWISE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct eltwise_fwd_pd_t; + +struct eltwise_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::eltwise; + + eltwise_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const eltwise_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::eltwise_d: + *(const eltwise_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common eltwise aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + eltwise_desc_t desc_; + const eltwise_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct eltwise_fwd_pd_t: public eltwise_pd_t { + typedef eltwise_fwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const + { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); } +}; + +struct eltwise_bwd_pd_t: public eltwise_pd_t { + typedef eltwise_bwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_bwd_pd_t(engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const { return true; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp new file mode 100644 index 0000000000..3b3e25456d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp @@ -0,0 +1,75 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include "engine.hpp" +#include "nstl.hpp" + +#include "c_types_map.hpp" +#include "../cpu/cpu_engine.hpp" + +namespace mkldnn { +namespace impl { + +engine_factory_t *engine_factories[] = { + &cpu::engine_factory, + nullptr, +}; + +static inline engine_factory_t *get_engine_factory(engine_kind_t kind) { + for (engine_factory_t **ef = engine_factories; *ef; ef++) + if ((*ef)->kind() == kind) + return *ef; + return nullptr; +} + +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +size_t mkldnn_engine_get_count(engine_kind_t kind) { + engine_factory_t *ef = get_engine_factory(kind); + return ef != nullptr ? ef->count() : 0; +} + +status_t mkldnn_engine_create(engine_t **engine, + engine_kind_t kind, size_t index) { + if (engine == nullptr) + return invalid_arguments; + + engine_factory_t *ef = get_engine_factory(kind); + if (ef == nullptr || index >= ef->count()) + return invalid_arguments; + + return ef->engine_create(engine, index); +} + +status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) { + if (engine == nullptr) + return invalid_arguments; + *kind = engine->kind(); + return success; +} + +status_t mkldnn_engine_destroy(engine_t *engine) { + /* TODO: engine->dec_ref_count(); */ + delete engine; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp new file mode 100644 index 0000000000..8ac8a29de5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ENGINE_HPP +#define ENGINE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive.hpp" +#include "utils.hpp" + +/** \brief An abstraction of an execution unit with shared resources + * + * Responsibilities: + * - Provide engine specific memory allocation + * - Provide engine specific primitive_desc_t creators + */ +struct mkldnn_engine: public mkldnn::impl::c_compatible { + mkldnn_engine(mkldnn::impl::engine_kind_t kind) + : kind_(kind) + {} + virtual ~mkldnn_engine() {} + + /** get kind of the current engine */ + virtual mkldnn::impl::engine_kind_t kind() const { return kind_; } + + /** allocate memory */ + virtual mkldnn::impl::status_t memory_create( + mkldnn::impl::memory_t **memory, + const mkldnn::impl::memory_desc_t *md, + void *handle) = 0; + + /** implementation section (typedefs) */ + + // TODO: remove engine? + typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)( + mkldnn::impl::reorder_pd_t **reorder_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *src_engine, + const mkldnn::impl::memory_desc_t *src_md, + mkldnn::impl::engine_t *dst_engine, + const mkldnn::impl::memory_desc_t *dst_md); + + typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)( + mkldnn::impl::concat_pd_t **concat_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, int concat_dim, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)( + mkldnn::impl::sum_pd_t **sum_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, const float *scales, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*primitive_desc_create_f)( + mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *); + + /* implementation section */ + + /** return the list of reorder implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const = 0; + + /** return the list of concat implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const = 0; + + /** return the list of sum implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const = 0; + + /** return the list of implementations. engine guarantees to return a + * NULL-terminated list */ + virtual const primitive_desc_create_f* get_implementation_list() const = 0; + +protected: + mkldnn::impl::engine_kind_t kind_; +}; + +namespace mkldnn { +namespace impl { + +struct engine_factory_t: public c_compatible { + virtual size_t count() const = 0; + virtual engine_kind_t kind() const = 0; + virtual status_t engine_create(engine_t **engine, size_t index) const = 0; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp new file mode 100644 index 0000000000..5a9f58cb1e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) { + bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc); + if (!args_ok) return invalid_arguments; + + auto id = inner_product_desc_t(); + id.primitive_kind = primitive_kind::inner_product; + id.prop_kind = prop_kind; + + id.diff_src_desc = id.src_desc = zero_md(); + id.diff_dst_desc = id.dst_desc = zero_md(); + id.diff_weights_desc = id.weights_desc = zero_md(); + id.diff_bias_desc = id.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + + (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc; + (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) = + *bias_desc; + + id.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && one_of(src_desc->ndims, 2, 3, 4, 5) + && dst_desc->ndims == 2 + && weights_desc->ndims == src_desc->ndims + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && array_cmp(&src_desc->dims[1], &weights_desc->dims[1], + src_desc->ndims - 1) + && dst_desc->dims[1] == weights_desc->dims[0]; + if (!consistency) return invalid_arguments; + + *ip_desc = id; + return success; +} +} + +status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc, + prop_kind_t prop_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc, + dst_desc); +} + +status_t mkldnn_inner_product_backward_data_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc) +{ + return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc, + nullptr, diff_dst_desc); +} + +status_t mkldnn_inner_product_backward_weights_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *src_desc, + const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc) { + return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc, + diff_bias_desc, diff_dst_desc); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp new file mode 100644 index 0000000000..091cf0f5d6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "inner_product_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? &desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_src_d(const_cast(desc)); } +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_wei_d(const_cast(desc)); } +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_bia_d(const_cast(desc)); } +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_dst_d(const_cast(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp new file mode 100644 index 0000000000..c426de632c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp @@ -0,0 +1,321 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef INNER_PRODUCT_PD_HPP +#define INNER_PRODUCT_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc); + +struct inner_product_fwd_pd_t; + +struct inner_product_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::inner_product; + + inner_product_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const inner_product_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::inner_product_d: + *(const inner_product_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common inner_product aux functions */ + + dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; } + dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; } + + dim_t ID() const { + return ndims() >= 5 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return ndims() >= 3 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t OD() const { + return ndims() >= 5 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return ndims() >= 3 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t KD() const { + return ndims() >= 5 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t KH() const { + return ndims() >= 4 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t KW() const { + return ndims() >= 3 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t IC_total() const { + return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1], + ndims() - 1); + } + + dim_t IC_total_padded() const { + auto src_d = desc()->prop_kind == prop_kind::backward_data + ? memory_desc_wrapper(diff_src_md()) + : memory_desc_wrapper(src_md()); + assert(src_d.is_blocking_desc()); + if (!src_d.is_blocking_desc()) return -1; + return utils::array_product(src_d.padded_dims() + 1, ndims() - 1); + } + + int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; } + + bool with_bias() const + { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + inner_product_desc_t desc_; + const inner_product_fwd_pd_t *hint_fwd_pd_; + + status_t template_set_default_params(memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t *bias_md) { + using namespace format_tag; + if (src_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, + utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw))); + } + if (dst_md.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md, nc)); + if (weights_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, + utils::pick(ndims() - 2, oi, oiw, oihw, oidhw))); + } + if (bias_md && bias_md->format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(*bias_md, x)); + return status::success; + } +}; + +struct inner_product_fwd_pd_t: public inner_product_pd_t { + typedef inner_product_fwd_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_fwd_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, weights_md_, dst_md_, + &bias_md_); + } +}; + +struct inner_product_bwd_data_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_data_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_data_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(diff_src_md_, weights_md_, + diff_dst_md_, nullptr); + } +}; + +struct inner_product_bwd_weights_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_weights_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_weights_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, diff_weights_md_, + diff_dst_md_, &diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp new file mode 100644 index 0000000000..fcf18b556f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t lrn_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc, + dim_t local_size, float alpha, float beta, float k) { + bool args_ok = true + && !any_null(lrn_desc, data_desc) + && one_of(alg_kind, lrn_within_channel, lrn_across_channels) + && one_of(prop_kind, forward_training, forward_inference, backward_data) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ld = lrn_desc_t(); + ld.primitive_kind = primitive_kind::lrn; + ld.prop_kind = prop_kind; + ld.alg_kind = alg_kind; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + ld.data_desc = *data_desc; + if (!is_fwd) + ld.diff_data_desc = *diff_data_desc; + else + ld.diff_data_desc = zero_md(); + ld.local_size = local_size; + ld.lrn_alpha = alpha; + ld.lrn_beta = beta; + ld.lrn_k = k; + + bool consistency = true + && ld.data_desc.ndims == 4; + if (ld.prop_kind == backward_data) + consistency = consistency + && ld.diff_data_desc.ndims == 4 + && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4); + if (!consistency) return invalid_arguments; + + *lrn_desc = ld; + return success; +} +} + +status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, dim_t local_size, float alpha, + float beta, float k) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr, + local_size, alpha, beta, k); +} + +status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, dim_t local_size, float alpha, + float beta, float k) { + return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc, + diff_data_desc, local_size, alpha, beta, k); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp new file mode 100644 index 0000000000..90886e9656 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef LRN_PD_HPP +#define LRN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct lrn_fwd_pd_t; + +struct lrn_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::lrn; + + lrn_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , ws_md_() + {} + + const lrn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::lrn_d: + *(const lrn_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common lrn aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + lrn_desc_t desc_; + const lrn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t ws_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct lrn_fwd_pd_t: public lrn_pd_t { + typedef lrn_fwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_fwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct lrn_bwd_pd_t: public lrn_pd_t { + typedef lrn_bwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_bwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp new file mode 100644 index 0000000000..3fddc0bd45 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MATH_UTILS_HPP +#define MATH_UTILS_HPP + +#include +#include + +#include "utils.hpp" +#include "nstl.hpp" +#include "mkldnn_traits.hpp" + +#if defined(MKLDNN_X86_64) +#include "immintrin.h" +#endif + +namespace mkldnn { +namespace impl { +namespace math { + +/** rounds @p f to an integer according to the mxcsr register */ +inline int mxcsr_round(float f) { +#if defined(MKLDNN_X86_64) + return _mm_cvtss_si32(_mm_load_ss(&f)); +#else + return (int)nearbyintf(f); // optimism +#endif +} + +template +inline typename utils::enable_if::value, + typename utils::remove_reference::type>::type +saturate(const acc_t &x) { + return (typename utils::remove_reference::type)x; +} + +template +inline typename utils::enable_if::value, + typename utils::remove_reference::type>::type +saturate(const acc_t &x) { + acc_t v = x; + if (v < (acc_t)nstl::numeric_limits::lowest()) + v = (acc_t)nstl::numeric_limits::lowest(); + if (v > (acc_t)nstl::numeric_limits::max()) + v = (acc_t)nstl::numeric_limits::max(); + return (typename utils::remove_reference::type)v; +} + +template +double saturate(const double &x) { + double v = x; + if (v < (double)nstl::numeric_limits::lowest()) + v = (double)nstl::numeric_limits::lowest(); + if (v > (double)nstl::numeric_limits::max()) + v = (double)nstl::numeric_limits::max(); + return v; +} + +template <> inline int8_t saturate(const uint8_t &x) { + return x <= 127u ? x : 127; +} + +template <> inline uint8_t saturate(const int8_t &x) { + return x >= 0 ? x : 0; +} + +template +typename utils::enable_if::value, out_t>::type +out_round(float v) { return (out_t)mxcsr_round(v); } + +template +typename utils::enable_if::value, out_t>::type +out_round(double v) { return (out_t)mxcsr_round((float)v); } + +template +typename utils::enable_if::value, out_t>::type +out_round(float v) { return v; } + +inline int gcd(int a, int b) { + a = impl::nstl::abs(a); + b = impl::nstl::abs(b); + if (a < b) { int x = a; a = b; b = x; } + + if (b == 0) return a; + + int r; + while ((r = a % b) != 0) { a = b; b = r; } + + return b; +} + +template +inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; } + +/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */ +inline int ilog2q(size_t v) { + if (v == 0) + return -1; + + int p = 0; +# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0) + CP(32); CP(16); CP(8); CP(4); CP(2); CP(1); +# undef CP + return p; +} + +template ::type> +inline U one_m_square(T x) { + return (U)(1 - x) * (1 + x); +} + +template ::type> +inline U x_m_square(T x) { + return (U)(1 - x) * x; +} + +/* activation */ +template ::type> +inline U relu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(s * alpha); +} +template ::type> +inline U relu_bwd(T dd, T s, A alpha) { + return s > 0 ? dd : (U)(dd * alpha); +} + +template ::type> +inline U tanh_fwd(T s) { + const float e = tanhf((float) s); + return (U)e; +} + +template ::type> +inline U tanh_bwd(T dd, T s) { + const float e = tanh_fwd((float) s); + return (U)(dd * (1 - e) * (1 + e)); +} + +template ::type> +inline U elu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); +} +template ::type> + inline U elu_bwd(T dd, T s, A alpha) { + return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); +} + +template ::type> +inline U square_fwd(T s) { + return s * s; +} + +template ::type> +inline U square_bwd(T dd, T s) { + return dd * 2 * s; +} + +template ::type> +inline U abs_fwd(T s) { + return s > 0 ? s : -s; +} + +template ::type> +inline U abs_bwd(T dd, T s) { + return s > 0 ? dd : s < 0 ? -dd : 0; +} + +template ::type> +inline U sqrt_fwd(T s) { + return s > 0 ? (U)(::sqrtf((float)(s))) : 0; +} + +template ::type> +inline U sqrt_bwd(T dd, T s) { + return s > 0 + ? (U)(dd / (2 * ::sqrtf((float)(s)))) + : 0; +} + +template ::type> +inline U linear_fwd(T s, A alpha, A beta) { + return (U)(alpha * s + beta); +} + +template ::type> +inline U linear_bwd(T dd, T s, A alpha, A beta) { + (void) s; + (void) beta; + return (U)(dd * alpha); +} + +template ::type> +inline U bounded_relu_fwd(T s, A alpha) { + s = s > 0 ? s : 0; + return s > alpha ? (U)(alpha) : s; +} + +template ::type> +inline U bounded_relu_bwd(T dd, T s, A alpha) { + return dd * (0 < s && s < alpha ? 1 : 0); +} + +template ::type> +inline U soft_relu_fwd(T s) { + float max_logf = 8.872284e+01; //::logf(FLT_MAX) + return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s; +} + +template ::type> +inline U soft_relu_bwd(T dd, T s) { + return (U)(dd / (1 + ::expf((float)(-s)))); +} + +template ::type> +inline U logistic_fwd(T s) { + U v = (U)(::expf((float) -s)); + return 1 / (1 + v); +} + +template ::type> +inline U logistic_bwd(T dd, T s) { + U v = logistic_fwd(s); + return dd * v * (1 - v); +} + +inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { + using namespace alg_kind; + using namespace utils; + const bool preserves_zero = true + && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); + return preserves_zero; +} + +inline float get_bias(const char *bias, size_t offset, data_type_t data_type) +{ + if (!bias) + return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits
::type *)bias)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp new file mode 100644 index 0000000000..cea849c96e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::data_type; + +namespace { +bool memory_desc_sanity_check(int ndims,const dims_t dims, + data_type_t data_type, format_kind_t format_kind) { + if (ndims == 0) return true; + + bool ok = true + && dims != nullptr + && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS + && one_of(data_type, f32, s32, s8, u8) + && format_kind != format_kind::undef; + if (!ok) return false; + for (int d = 0; d < ndims; ++d) + if (dims[d] < 0) return false; + + return true; +} + +bool memory_desc_sanity_check(const memory_desc_t *md) { + if (md == nullptr) return false; + return memory_desc_sanity_check(md->ndims, md->dims, md->data_type, + format_kind::any); +} +} + +status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims, + const dims_t dims, data_type_t data_type, format_tag_t tag) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0 || tag == format_tag::undef) { + *memory_desc = types::zero_md(); + return success; + } + + format_kind_t format_kind = types::format_tag_to_kind(tag); + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind; + + status_t status = success; + if (tag == format_tag::undef) { + status = invalid_arguments; + } else if (tag == format_tag::any) { + // nop + } else if (format_kind == format_kind::blocked) { + status = memory_desc_wrapper::compute_blocking(md, tag); + } else { + assert(!"unreachable"); + status = invalid_arguments; + } + + if (status == success) + *memory_desc = md; + + return status; +} + +status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc, + int ndims, const dims_t dims, data_type_t data_type, + const dims_t strides) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0) { + *memory_desc = types::zero_md(); + return success; + } + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind::blocked; + + dims_t default_strides = {0}; + if (strides == nullptr) { + default_strides[md.ndims - 1] = 1; + for (int d = md.ndims - 2; d >= 0; --d) + default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1]; + strides = default_strides; + } else { + /* TODO: add sanity check for the provided strides */ + } + + array_copy(md.format_desc.blocking.strides, strides, md.ndims); + + *memory_desc = md; + + return status::success; +} + +status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md, + const memory_desc_t *parent_md, const dims_t dims, + const dims_t offsets) { + if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md)) + return invalid_arguments; + + const memory_desc_wrapper src_d(parent_md); + + for (int d = 0; d < src_d.ndims(); ++d) { + if (dims[d] < 0 || offsets[d] < 0 + || (offsets[d] + dims[d] > src_d.dims()[d])) + return invalid_arguments; + } + + if (src_d.format_kind() != format_kind::blocked) + return unimplemented; + + dims_t blocks; + src_d.compute_blocks(blocks); + + memory_desc_t dst_d = *parent_md; + auto &dst_d_blk = dst_d.format_desc.blocking; + + /* TODO: put this into memory_desc_wrapper */ + for (int d = 0; d < src_d.ndims(); ++d) { + /* very limited functionality for now */ + const bool ok = true + && offsets[d] % blocks[d] == 0 /* [r1] */ + && src_d.padded_offsets()[d] == 0 + && (false + || dims[d] % blocks[d] == 0 + || dims[d] < blocks[d]); + if (!ok) + return unimplemented; + + const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d]; + + dst_d.dims[d] = dims[d]; + dst_d.padded_dims[d] = is_right_border + ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d]; + dst_d.padded_offsets[d] = src_d.padded_offsets()[d]; + dst_d.offset0 += /* [r1] */ + offsets[d] / blocks[d] * dst_d_blk.strides[d]; + } + + *md = dst_d; + + return success; +} + +int mkldnn_memory_desc_equal(const memory_desc_t *lhs, + const memory_desc_t *rhs) { + if (lhs == rhs) return 1; + if (any_null(lhs, rhs)) return 0; + return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); +} + +size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) { + if (md == nullptr) return 0; + return memory_desc_wrapper(*md).size(); +} + +status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md, + engine_t *engine, void *handle) { + if (any_null(memory, engine)) return invalid_arguments; + memory_desc_t z_md = types::zero_md(); + return engine->memory_create(memory, md ? md : &z_md, handle); +} + +status_t mkldnn_memory_get_memory_desc(const memory_t *memory, + const memory_desc_t **md) { + if (any_null(memory, md)) return invalid_arguments; + *md = memory->md(); + return success; +} + +status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) { + if (any_null(memory, engine)) return invalid_arguments; + *engine = memory->engine(); + return success; +} + +status_t mkldnn_memory_get_data_handle(const memory_t *memory, + void **handle) { + if (any_null(handle)) + return invalid_arguments; + if (memory == nullptr) { + *handle = nullptr; + return success; + } + return memory->get_data_handle(handle); +} + +status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) { + if (any_null(memory)) return invalid_arguments; + return memory->set_data_handle(handle); +} + +status_t mkldnn_memory_destroy(memory_t *memory) { + delete memory; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp new file mode 100644 index 0000000000..03dfee01ff --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp @@ -0,0 +1,63 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_HPP +#define MEMORY_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" + +struct mkldnn_memory: public mkldnn::impl::c_compatible { + mkldnn_memory(mkldnn::impl::engine_t *engine, + const mkldnn::impl::memory_desc_t *md) + : engine_(engine), md_(*md) {} + virtual ~mkldnn_memory() {} + + /** allocates/initializes memory */ + virtual mkldnn::impl::status_t init() = 0; + + /** returns memory's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + /** returns memory's description */ + const mkldnn::impl::memory_desc_t *md() const { return &md_; } + + /** returns data handle */ + virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0; + + /** sets data handle */ + virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0; + + /** zeros padding */ + virtual mkldnn::impl::status_t zero_pad() const + { return mkldnn::impl::status::success; } + +protected: + mkldnn::impl::engine_t *engine_; + const mkldnn::impl::memory_desc_t md_; + +private: + mkldnn_memory() = delete; + mkldnn_memory(const mkldnn_memory &) = delete; + mkldnn_memory(mkldnn_memory &&) = delete; + mkldnn_memory &operator=(const mkldnn_memory &) = delete; + mkldnn_memory &operator=(mkldnn_memory &&) = delete; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp new file mode 100644 index 0000000000..8a99be33f3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp @@ -0,0 +1,212 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t fill_blocked(memory_desc_t &md, + std::initializer_list perm, + std::initializer_list inner_blks, + std::initializer_list inner_idxs) { + const bool ok = true + && perm.size() == (size_t)md.ndims + && inner_blks.size() == inner_idxs.size(); + if (!ok) return status::invalid_arguments; + + md.offset0 = 0; + + blocking_desc_t &blk = md.format_desc.blocking; + + dim_t block_size = 1; + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + + blk.inner_nblks = (int)inner_blks.size(); + + int iblk = 0; + for (const auto &b: inner_idxs) + blk.inner_idxs[iblk++] = b; + + iblk = 0; + for (const auto &b: inner_blks) { + int dim = blk.inner_idxs[iblk]; + block_size *= b; + blocks[dim] *= b; + blk.inner_blks[iblk++] = b; + } + + utils::array_set(md.padded_offsets, 0, md.ndims); + for (int d = 0; d < md.ndims; ++d) + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + + dim_t stride = block_size; + // if only we use C++14, the initializer_list would have rbegin()/rend()... + for (int d = 0; d < md.ndims; ++d) + stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d]; + + for (const auto &d: perm) { + if (md.padded_dims[d] == 0) { + blk.strides[d] = 1; + continue; + } + stride /= md.padded_dims[d] / blocks[d]; + blk.strides[d] = stride; + } + + assert(stride == block_size); + + return status::success; +} + +status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag) +{ + using namespace format_tag; + + if (memory_desc.ndims == 0) return status::invalid_arguments; + +# define C(tag, ... /* perm, inner_blks, inner_idxs */) \ + case tag: return fill_blocked(memory_desc, __VA_ARGS__) + + switch (tag) { + C(a, {0}, {}, {}); + C(ab, {0, 1}, {}, {}); + C(abc, {0, 1, 2}, {}, {}); + C(abcd, {0, 1, 2, 3}, {}, {}); + C(abcde, {0, 1, 2, 3, 4}, {}, {}); + C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {}); + C(abdec, {0, 1, 3, 4, 2}, {}, {}); + C(acb, {0, 2, 1}, {}, {}); + C(acbde, {0, 2, 1, 3, 4}, {}, {}); + C(acdb, {0, 2, 3, 1}, {}, {}); + C(acdeb, {0, 2, 3, 4, 1}, {}, {}); + C(ba, {1, 0}, {}, {}); + C(bac, {1, 0, 2}, {}, {}); + C(bacd, {1, 0, 2, 3}, {}, {}); + C(bcda, {1, 2, 3, 0}, {}, {}); + C(cba, {2, 1, 0}, {}, {}); + C(cdba, {2, 3, 1, 0}, {}, {}); + C(cdeba, {2, 3, 4, 1, 0}, {}, {}); + C(decab, {3, 4, 2, 0, 1}, {}, {}); + + C(Abc4a, {0, 1, 2}, {4}, {0}); + C(aBc4b, {0, 1, 2}, {4}, {1}); + C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1}); + C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0}); + C(Abcd4a, {0, 1, 2, 3}, {4}, {0}); + C(aBcd4b, {0, 1, 2, 3}, {4}, {1}); + C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0}); + C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2}); + C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0}); + C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1}); + C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0}); + C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1}); + C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1}); + C(aBdc4b, {0, 1, 3, 2}, {4}, {1}); + C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1}); + C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1}); + C(Acb4a, {0, 2, 1}, {4}, {0}); + C(Acdb4a, {0, 2, 3, 1}, {4}, {0}); + C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0}); + + C(Abc16a, {0, 1, 2}, {16}, {0}); + C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1}); + C(aBc16b, {0, 1, 2}, {16}, {1}); + C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0}); + C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0}); + C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1}); + C(aBc8b, {0, 1, 2}, {8}, {1}); + C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1}); + C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0}); + C(Abcd16a, {0, 1, 2, 3}, {16}, {0}); + C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1}); + C(aBcd16b, {0, 1, 2, 3}, {16}, {1}); + C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0}); + C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2}); + C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1}); + C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1}); + C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); + C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); + C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); + C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1}); + C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1}); + C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0}); + C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2}); + C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2}); + C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1}); + C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0}); + C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1}); + C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1}); + C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0}); + C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2}); + C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1}); + C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2}); + C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2}); + C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2}); + C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0}); + C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1}); + C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1}); + C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1}); + C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1}); + C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0}); + C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2}); + C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2}); + C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1}); + C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1}); + C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2}); + C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1}); + C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2}); + C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2}); + C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1}); + C(aBdc16b, {0, 1, 3, 2}, {16}, {1}); + C(aBdc8b, {0, 1, 3, 2}, {8}, {1}); + C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1}); + C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1}); + C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1}); + C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1}); + C(Acb16a, {0, 2, 1}, {16}, {0}); + C(Acb8a, {0, 2, 1}, {8}, {0}); + C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); + C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); + C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); + C(Acdb8a, {0, 2, 3, 1}, {8}, {0}); + C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0}); + C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0}); + C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); + C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); + default: break; + } + +#undef C + + return status::invalid_arguments; +} + +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp new file mode 100644 index 0000000000..1758f9078a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_DESC_WRAPPER_HPP +#define MEMORY_DESC_WRAPPER_HPP + +#include + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +/** thin wrapper class over \struct memory_desc_t which allows easy + * manipulations with underlying C structure, which is taken by reference */ +struct memory_desc_wrapper: public c_compatible { + const memory_desc_t *md_; + + /** constructor which takes a reference to a constant underlying C memory + * descriptor \param md */ + memory_desc_wrapper(const memory_desc_t *md): md_(md) {} + memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {} + + /* implementing attributes */ + int ndims() const { return md_->ndims; } + const dims_t &dims() const { return md_->dims; } + data_type_t data_type() const { return md_->data_type; } + + const dims_t &padded_dims() const { return md_->padded_dims; } + const dims_t &padded_offsets() const { return md_->padded_offsets; } + dim_t offset0() const { return md_->offset0; } + + format_kind_t format_kind() const { return md_->format_kind; } + + bool is_blocking_desc() const + { return format_kind() == format_kind::blocked; } + bool is_wino_desc() const + { return format_kind() == format_kind::wino; } + bool is_rnn_packed_desc() const + { return format_kind() == format_kind::rnn_packed; } + + const blocking_desc_t &blocking_desc() const { + assert(is_blocking_desc()); + return md_->format_desc.blocking; + } + const wino_desc_t &wino_desc() const { + assert(is_wino_desc()); + return md_->format_desc.wino_desc; + } + const rnn_packed_desc_t &rnn_packed_desc() const { + assert(is_rnn_packed_desc()); + return md_->format_desc.rnn_packed_desc; + } + + const memory_extra_desc_t &extra() const { return md_->extra; } + + /* some useful function */ + + /** returns the number of elements including padding if \param with_padding + * is true, and the number of data elements otherwise */ + dim_t nelems(bool with_padding = false) const { + if (is_zero()) return 0; + return utils::array_product( + with_padding ? padded_dims() : dims(), ndims()); + } + + /** returns true if memory descriptor is zero */ + bool is_zero() const { return ndims() == 0; } + + /** returns true if memory descriptor contains zero as one of its dim */ + bool has_zero_dim() const { return nelems() == 0; } + + /** return the size of data type (a shortcut) */ + size_t data_type_size() const + { return types::data_type_size(data_type()); } + + /** return the size of data type of additional buffer */ + size_t additional_buffer_data_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) + return sizeof(int32_t); + return 0; + } + + /** return true if memory format has additional buffer */ + bool is_additional_buffer() const { + return (extra().flags & memory_extra_flags::compensation_conv_s8s8); + } + + /** returns the size of additional buffer */ + size_t additional_buffer_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) { + int cmask = extra().compensation_mask; + assert(cmask == 1 || cmask == 3); + dim_t prod = 1; + for (int d = 0; d < ndims(); ++d) + if (cmask & (1<(max_size, + padded_dims()[d] / blocks[d] * bd.strides[d]); + + if (max_size == 1 && bd.inner_nblks != 0) { + max_size = utils::array_product(bd.inner_blks, bd.inner_nblks); + } + + return max_size * data_type_size() + additional_buffer_size(); + } + } + + /** returns true if data is dense in memory */ + bool is_dense(bool with_padding = false) const { + if (utils::one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + return nelems(with_padding) * data_type_size() == size(); + } + + /** returns true if memory desc is fully defined */ + bool is_defined() const { return format_kind() != format_kind::any; } + + /** returns true if the only (potentially) padded dim is \param dim */ + bool only_padded_dim(int dim) const { + for (int d = 0; d < ndims(); ++d) + if (d != dim && dims()[d] != padded_dims()[d]) + return false; + return true; + } + + /** returns true if memory desc has blocked layout and block dims are 1s */ + bool is_plain() const { + if (!is_blocking_desc()) return false; + return blocking_desc().inner_nblks == 0; + } + + /** returns overall block sizes */ + void compute_blocks(dims_t blocks) const { + if (!is_blocking_desc()) { + utils::array_set(blocks, 0, ndims()); + return; + } + + utils::array_set(blocks, 1, ndims()); + + const auto &bd = blocking_desc(); + for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) + blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk]; + } + + /* comparison section */ + + bool operator==(const memory_desc_wrapper &rhs) const + { return *this->md_ == *rhs.md_; } + bool operator!=(const memory_desc_wrapper &rhs) const + { return !operator==(rhs); } + bool operator==(const memory_desc_t &rhs) const + { return operator==(memory_desc_wrapper(rhs)); } + bool operator!=(const memory_desc_t &rhs) const + { return !operator==(rhs); } + + /** returns true if data (w/o padding if with_padding == false and w/ + * padding otherwise) have the same physical structure, i.e. dimensions, + * strides, and blocked structure. Depending on with_data_type flag + * data_type is taken or not taken into account. dim_start allows to check + * similarity for the logical part of data [dim_start .. ndims()]. + * CAUTION: format kind any and undef are not similar to whatever, hence the + * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */ + /* TODO: revise */ + bool similar_to(const memory_desc_wrapper &rhs, + bool with_padding = true, bool with_data_type = true, + int dim_start = 0) const; + + /** returns true if one memory can be reordered to another */ + bool consistent_with(const memory_desc_wrapper &rhs) const; + + /** returns true if the memory desc corresponds to the given format tag and + * strides. + * @sa memory_desc_matches_tag */ + bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const { + return memory_desc_matches_tag(*md_, tag, strides); + } + + /** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ + template + format_tag_t matches_one_of_tag(Tags ...tags) const { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(*md_, tag)) + return tag; + } + return format_tag::undef; + } + + /* offset section */ + + /** returns physical offset by logical one. logical offset is represented by + * an array \param pos. if \param is_pos_padded is true \param pos + * represents the position in already padded area */ + dim_t off_v(const dims_t pos, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + const blocking_desc_t &blk = blocking_desc(); + + dims_t pos_copy = {0}; + for (int d = 0; d < ndims(); ++d) + pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]); + + dim_t phys_offset = offset0(); + + if (blk.inner_nblks > 0) { + dim_t blk_stride = 1; + for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) { + const int d = blk.inner_idxs[iblk]; + const dim_t p = pos_copy[d] % blk.inner_blks[iblk]; + + phys_offset += p * blk_stride; + + pos_copy[d] /= blk.inner_blks[iblk]; + + blk_stride *= blk.inner_blks[iblk]; + } + } + + for (int d = 0; d < ndims(); ++d) { + const dim_t p = pos_copy[d]; + phys_offset += p * blk.strides[d]; + } + + return phys_offset; + } + + /** returns physical offset by logical one. logical offset is represented by + * a scalar \param l_offset. if \param is_pos_padded is true, \param + * l_offset represents logical offset in already padded area */ + dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + dims_t pos; + for (int rd = 0; rd < ndims(); ++rd) { + const int d = ndims() - 1 - rd; + const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d]; + pos[d] = l_offset % cur_dim; + l_offset /= cur_dim; + } + return off_v(pos, is_pos_padded); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) */ + template + dim_t off(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, false); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) in already + * padded area */ + template + dim_t off_padding(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, true); + } + + /** returns physical offset by logical one. Logical offset is represented by + * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a + * user responsibility to adjust the result to get offset within blocks */ + template + dim_t blk_off(Args... args) const { + return _blk_off(args...); + } + + template + dim_t blk_off(T xn, Args... args) const { + return skip_first + ? blk_off(args...) + : blk_off(xn, args...); + } + + /* static functions section */ + /* TODO: replace with non-static, once md_ becomes non-const ref */ + + static status_t compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag); + +private: + /* TODO: put logical_offset in utils */ + template + dim_t logical_offset(T x0) const { return x0; } + + template + dim_t logical_offset(T xn, Args... args) const { + const size_t n_args = sizeof...(args); + return xn * utils::array_product( + &dims()[ndims() - n_args]) + logical_offset(args...); + } + + template + dim_t _blk_off() const { return offset0(); } + + template + dim_t _blk_off(T xc, Args ...args) const { + assert(is_blocking_desc()); + constexpr int dc = ORIG_LEN - sizeof...(args) - 1; + return xc * blocking_desc().strides[dc] + + _blk_off(args...); + } +}; + +inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, + bool with_padding, bool with_data_type, int dim_start) const { + using namespace utils; + + if (one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + if (is_wino_desc() || is_rnn_packed_desc()) + return false; + + const int ds = dim_start; + const auto &blk = blocking_desc(); + const auto &r_blk = rhs.blocking_desc(); + + return ndims() == rhs.ndims() + && dim_start <= ndims() /* guard */ + && format_kind() == rhs.format_kind() + && IMPLICATION(with_data_type, data_type() == rhs.data_type()) + && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds) + && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds) + && blk.inner_nblks == r_blk.inner_nblks + && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks) + && IMPLICATION(with_padding, true + && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds, + ndims() - ds) + && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds, + ndims() - ds)); +} + +inline bool memory_desc_wrapper::consistent_with( + const memory_desc_wrapper &rhs) const { + if (ndims() == rhs.ndims()) { + for (int d = 0; d < ndims(); ++d) { + if (dims()[d] != rhs.dims()[d]) return false; + } + return true; + } else { + /* TODO: revise. + * is the following possible? + * [1, a, b] <--reorder--> [a, b] + * [a, 1, b] <--reorder--> [a, b] + * not, at least for now */ + return false; + } +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp new file mode 100644 index 0000000000..ec077b308c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp @@ -0,0 +1,295 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_TRACKING_HPP +#define MEMORY_TRACKING_HPP + +#include +#include + +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace memory_tracking { + +/* Memory tracking capabilities + * + * The main purpose of this header file is to provide uniform way to register + * required memory for a scratchpad at a primitive descriptor creation time + * and then easily access it having only the base address of the scratchpad. + * + * Primitives might contain multiple disjoint parts that require temporary + * buffers (known as scratchpad) during their execution. A primitive descriptor + * should summarize all the needs into one single number -- the buffer size + * that would be requested from a user. At execution time, the corresponding + * primitive will receive a base pointer to a scratchpad. It then needs to + * provide each part of algorithm the corresponding piece of memory. Three main + * challenges here are: + * 1. Track correct offset (from the base scratchpad address) for each piece + * 2. Algorithm might require that different memory pieces to be aligned, so + * the scratchpad size is no more just a sum of size of the corresponding + * subparts. + * 3. While a primitive is responsible for its scratchpad, the implementation + * might use some other basic blocks (e.g. cpu_reducer) that also require + * scratchpad memory. So there should be a simple way of passing the + * information back and force between the main algorithm (a primitive) and + * auxiliary stuff that lives completely separately from it (e.g. reducer). + * + * To address these challenges this header file provides 3 structures: + * 1. registry_t -- the class the stores the information about requested + * memory. The information includes required size and desired + * alignment for each piece. This class is also responsible + * for computing the right offset to a given piece using the + * base pointer. + * This class is basically a ledger with all entries. + * Lives in primitive descriptors. + * + * 2. registrar_t -- the interface to a registry_t to book memory. Used at + * primitive descriptor creation time only. Contains a + * reference to the corresponding *mutable* registry. + * Always modifiable. + * Allows chaining (using prefixes). + * + * 3. grantor_t -- the interface to a registry_t to access memory. Used at + * primitive execution time only. Contains a reference to + * the corresponding *constant* registry and base pointer. + * Always constant. + * Allows chaining (using prefixes). + * + * Both registrar_t and grantor_t allow chaining with extra prefix provided. + * The feature is useful when a primitive offload a part of computations to + * some other primitives which require their own scratchpad space + * (e.g. reducer). Prefixes are used to avoid key collision in cases when + * multiple sub-primitive (e.g. multiple reducers) are used. + * + * A short example below demonstrates how to use aforementioned classes. In it + * the main primitive is convolution that uses scratchpad for keeping padded + * bias. It also needs a reducer, that needs its own space as well. + * + * ``` c++ + * struct reducer_t { + * static void init(registrar_t &scratchpad) { + * // preserve space for the reduction (one page aligned) + * scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096); + * } + * + * void exec(const grantor_t &scratchpad) { + * // get the pointer to preserved space. scratchpad came from + * // upper primitive (convolution in this example) + * auto space = scratchpad.get(key_reducer_space); + * + * space[:] += ...; + * } + * }; + * + * struct conv_t { + * struct pd_t { + * void init() { + * registrar_t scratchpad(scratchpad_registry_); + * + * // preserve a space for padded bias (using default alignment) + * scratchpad.book(key_conv_padded_bias, 128); + * + * // create a proxy registrar for the reducer All entries made + * // by reducer would live in convolution's registry, but would + * // have their own `prefix`, so no interference with conv's + * // buffers. + * registrar_t reducer_scratchpad(scratchpad, prefix_reducer); + * + * reducer_t::init(reducer_scratchpad); + * } + * + * registry_t scratchpad_registry_; + * } + * + * void exec() { + * // get the base pointer to a scratchpad memory from a user + * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD); + * + * // create a grantor to the scratchpad (and provide the base + * // pointer). + * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr); + * + * // access the padded_bias (need only key name and the grantor) + * auto padded_bias = scratchpad.get(key_conv_padded_bias); + * + * // to give the `right` grantor to reducer we need to add the + * // corresponding prefix, so that reducer would be able to access + * // its keys. The call is very similar to the one in pd_t::init + * // with only difference in types: grantor_t vs registrar_t. + * grantor_t reducer_scratchpad(scratchpad, prefix_reducer); + * reducer->exec(reducer_scratchpad); + * } + * }; + * ``` + */ + + +/* namespace with common keys and prefixes */ +namespace names { +enum { + key_none = 0, + key_bnorm_tmp_mean, + key_bnorm_tmp_var, + key_bnorm_tmp_diff_ss, + key_bnorm_tmp_stats, + key_bnorm_reduction, + key_concat_iptrs, + key_concat_istrides, + key_concat_nelems, + key_concat_optrs, + key_conv_adjusted_scales, + key_conv_bia_reduction, + key_conv_gemm_col, + key_conv_gemm_imtr, + key_conv_int_dat_in_acc_dt, + key_conv_padded_bias, + key_conv_rtus_space, + key_conv_tr_diff_dst, + key_conv_tr_diff_dst_bctx, + key_conv_tr_src, + key_conv_tr_src_bctx, + key_conv_wei_reduction, + key_conv_wei_bia_reduction, + key_conv_wei_bia_reduction_bctx, + key_iprod_int_dat_in_acc_dt, + key_reducer_space, + key_reducer_space_bctx, + key_reorder_wino_plain, + key_reorder_wino_transform_space, + key_reorder_rnn_weights_quantization, + key_reorder_rnn_weights_reduction, + key_rnn_space, + key_rnn_ptrs_bia, + key_rnn_ptrs_wei_layer, + key_rnn_ptrs_wei_iter, + key_softmax_reduction, + key_wino_U, + key_wino_V, + key_wino_M, + key_barrier, +}; + +enum { + prefix_none = 0, + prefix_reducer_bia, + prefix_reducer_wei, +}; +} + +// level 0: 00 00 00 xxx +// level 1: 00 00 aa xxx +// level 2: 00 aa bb xxx +// level 3: aa bb cc xxx +// max # of levels: 3 + 1 (base_level) +// here: +// xxx : [1 .. MAX_KEY) : key +// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3 + +using key_t = uint32_t; +enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), }; + +/// generates global key based on a prefix and a local key +inline key_t make_key(key_t prefix, key_t key) { return prefix + key; } + +/// generates global prefix based on the global parent and the local ones +inline key_t make_prefix(key_t parent_prefix, key_t prefix) +{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; } + +struct registrar_t; +struct grantor_t; + +struct registry_t { + void book(const key_t &key, size_t size, size_t alignment) { + if (size == 0) return; + assert(offset_map_.count(key) == 0); + + size = utils::rnd_up(size, minimal_alignment); + alignment = nstl::max(alignment, minimal_alignment); + offset_map_[key] = entry_t{size_, size, alignment}; + + size_ += size + alignment - minimal_alignment; + } + + void *get(const key_t &key, void *base_ptr) const { + if (base_ptr == nullptr) { assert(size() == 0); return nullptr; } + if (offset_map_.count(key) != 1) return nullptr; + + const auto &e = offset_map_.at(key); + base_ptr = utils::align_ptr(base_ptr, minimal_alignment); + char *ptr = (char *)base_ptr + e.offset; + return utils::align_ptr(ptr, e.alignment); + } + + size_t size() const + { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; } + + registrar_t registrar(); + grantor_t grantor(void *base_ptr) const; + +protected: + enum { minimal_alignment = 64 }; + struct entry_t { size_t offset, size, alignment; }; + + std::unordered_map offset_map_; + size_t size_ = 0; +}; + +struct registrar_t { + enum { default_alignment = 64 }; + + registrar_t(registry_t ®istry): registry_(registry), prefix_(0) {} + registrar_t(registrar_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) {} + + void book(const key_t &key, size_t size, + size_t alignment = default_alignment) + { registry_.book(make_key(prefix_, key), size, alignment); } + +protected: + registry_t ®istry_; + const key_t prefix_; +}; + +struct grantor_t { + grantor_t(const registry_t ®istry, void *base_ptr) + : registry_(registry), prefix_(0), base_ptr_(base_ptr) {} + grantor_t(const grantor_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) + , base_ptr_(parent.base_ptr_) {} + + template T *get(const key_t &key) const + { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); } + +protected: + const registry_t ®istry_; + const key_t prefix_; + void *base_ptr_; +}; + +inline registrar_t registry_t::registrar() { return registrar_t(*this); } +inline grantor_t registry_t::grantor(void *base_ptr) const +{ return grantor_t(*this, base_ptr); } + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp new file mode 100644 index 0000000000..2ef4a8fddc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#define DPRINT(...) do { \ + int l = snprintf(str + written_len, str_len, __VA_ARGS__); \ + if (l < 0) return l; \ + if ((size_t)l >= str_len) return -1; \ + written_len += l; str_len -= l; \ +} while(0) + +int mkldnn_md2fmt_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1u) + return -1; + + int written_len = 0; + + if (mdesc == nullptr) { + DPRINT("%s::%s::", + mkldnn_dt2str(data_type::undef), + mkldnn_fmt_kind2str(format_kind::undef)); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + DPRINT("%s:", mkldnn_dt2str(md.data_type())); + + bool padded_dims = false, padded_offsets = false; + for (int d = 0; d < md.ndims(); ++d) { + if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true; + if (md.padded_offsets()[d] != 0) padded_offsets = true; + } + bool offset0 = md.offset0(); + DPRINT("%s%s%s:", + padded_dims ? "p" : "", + padded_offsets ? "o" : "", + offset0 ? "0" : ""); + + DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind())); + + if (!md.is_blocking_desc()) { + /* TODO: extend */ + DPRINT("%s:", ""); + } else { + const auto &blk = md.blocking_desc(); + + dims_t blocks; + md.compute_blocks(blocks); + + char dim_chars[MKLDNN_MAX_NDIMS + 1]; + + bool plain = true; + for (int d = 0; d < md.ndims(); ++d) { + dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d; + if (blocks[d] != 1) plain = false; + } + + dims_t strides; + utils::array_copy(strides, blk.strides, md.ndims()); + utils::simultaneous_sort(strides, dim_chars, md.ndims(), + [](dim_t a, dim_t b) { return b - a; }); + + dim_chars[md.ndims()] = '\0'; + DPRINT("%s", dim_chars); + + if (!plain) { + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + DPRINT("%d%c", (int)blk.inner_blks[iblk], + 'a' + (char)blk.inner_idxs[iblk]); + } + } + + DPRINT("%s", ":"); + } + + DPRINT("f%lx", (long)md.extra().flags); + + return written_len; +} + +int mkldnn_md2dim_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1) + return -1; + + int written_len = 0; + + if (mdesc == nullptr || mdesc->ndims == 0) { + DPRINT("%s", ""); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + for (int d = 0; d < md.ndims() - 1; ++d) + DPRINT("%" PRId64 "x", md.dims()[d]); + DPRINT("%" PRId64, md.dims()[md.ndims() - 1]); + + return written_len; +} + +#undef DPRINT diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp new file mode 100644 index 0000000000..16a8f7ea5e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp @@ -0,0 +1,365 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#include + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +const char *mkldnn_status2str(mkldnn_status_t v) { + if (v == mkldnn_success) return "success"; + if (v == mkldnn_out_of_memory) return "out_of_memory"; + if (v == mkldnn_try_again) return "try_again"; + if (v == mkldnn_invalid_arguments) return "invalid_arguments"; + if (v == mkldnn_not_ready) return "not_ready"; + if (v == mkldnn_unimplemented) return "unimplemented"; + if (v == mkldnn_iterator_ends) return "iterator_ends"; + if (v == mkldnn_runtime_error) return "runtime_error"; + if (v == mkldnn_not_required) return "not_required"; + assert(!"unknown status"); + return "unknown status"; +} + +const char *mkldnn_dt2str(mkldnn_data_type_t v) { + if (v == mkldnn_data_type_undef) return "undef"; + if (v == mkldnn_f32) return "f32"; + if (v == mkldnn_s32) return "s32"; + if (v == mkldnn_s8) return "s8"; + if (v == mkldnn_u8) return "u8"; + assert(!"unknown dt"); + return "unknown dt"; +} + +const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) { + if (v == mkldnn_format_kind_undef) return "undef"; + if (v == mkldnn_format_kind_any) return "any"; + if (v == mkldnn_blocked) return "blocked"; + if (v == mkldnn_format_kind_wino) return "wino"; + if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed"; + assert(!"unknown fmt_kind"); + return "unknown fmt_kind"; +} + +const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) { + if (v == mkldnn_format_tag_undef) return "undef"; + if (v == mkldnn_format_tag_any) return "format_tag_any"; + if (v == mkldnn_a) return "a"; + if (v == mkldnn_ab) return "ab"; + if (v == mkldnn_abc) return "abc"; + if (v == mkldnn_abcd) return "abcd"; + if (v == mkldnn_abcde) return "abcde"; + if (v == mkldnn_abcdef) return "abcdef"; + if (v == mkldnn_abdec) return "abdec"; + if (v == mkldnn_acb) return "acb"; + if (v == mkldnn_acbde) return "acbde"; + if (v == mkldnn_acdb) return "acdb"; + if (v == mkldnn_acdeb) return "acdeb"; + if (v == mkldnn_ba) return "ba"; + if (v == mkldnn_bac) return "bac"; + if (v == mkldnn_bacd) return "bacd"; + if (v == mkldnn_bcda) return "bcda"; + if (v == mkldnn_cba) return "cba"; + if (v == mkldnn_cdba) return "cdba"; + if (v == mkldnn_cdeba) return "cdeba"; + if (v == mkldnn_decab) return "decab"; + if (v == mkldnn_Abc16a) return "Abc16a"; + if (v == mkldnn_ABc16a16b) return "ABc16a16b"; + if (v == mkldnn_aBc16b) return "aBc16b"; + if (v == mkldnn_ABc16b16a) return "ABc16b16a"; + if (v == mkldnn_Abc4a) return "Abc4a"; + if (v == mkldnn_aBc4b) return "aBc4b"; + if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b"; + if (v == mkldnn_ABc4b4a) return "ABc4b4a"; + if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a"; + if (v == mkldnn_ABc8a8b) return "ABc8a8b"; + if (v == mkldnn_aBc8b) return "aBc8b"; + if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b"; + if (v == mkldnn_ABc8b8a) return "ABc8b8a"; + if (v == mkldnn_Abcd16a) return "Abcd16a"; + if (v == mkldnn_ABcd16a16b) return "ABcd16a16b"; + if (v == mkldnn_aBcd16b) return "aBcd16b"; + if (v == mkldnn_ABcd16b16a) return "ABcd16b16a"; + if (v == mkldnn_aBCd16b16c) return "aBCd16b16c"; + if (v == mkldnn_aBCd16c16b) return "aBCd16c16b"; + if (v == mkldnn_Abcd4a) return "Abcd4a"; + if (v == mkldnn_aBcd4b) return "aBcd4b"; + if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b"; + if (v == mkldnn_ABcd4b4a) return "ABcd4b4a"; + if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c"; + if (v == mkldnn_aBCd4c4b) return "aBCd4c4b"; + if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a"; + if (v == mkldnn_ABcd8a8b) return "ABcd8a8b"; + if (v == mkldnn_aBcd8b) return "aBcd8b"; + if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b"; + if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b"; + if (v == mkldnn_ABcd8b8a) return "ABcd8b8a"; + if (v == mkldnn_aBCd8b8c) return "aBCd8b8c"; + if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c"; + if (v == mkldnn_aBCd8c8b) return "aBCd8c8b"; + if (v == mkldnn_Abcde16a) return "Abcde16a"; + if (v == mkldnn_ABcde16a16b) return "ABcde16a16b"; + if (v == mkldnn_aBcde16b) return "aBcde16b"; + if (v == mkldnn_ABcde16b16a) return "ABcde16b16a"; + if (v == mkldnn_aBCde16b16c) return "aBCde16b16c"; + if (v == mkldnn_aBCde16c16b) return "aBCde16c16b"; + if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c"; + if (v == mkldnn_Abcde4a) return "Abcde4a"; + if (v == mkldnn_aBcde4b) return "aBcde4b"; + if (v == mkldnn_ABcde4b4a) return "ABcde4b4a"; + if (v == mkldnn_aBCde4b4c) return "aBCde4b4c"; + if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c"; + if (v == mkldnn_aBCde4c4b) return "aBCde4c4b"; + if (v == mkldnn_Abcde8a) return "Abcde8a"; + if (v == mkldnn_ABcde8a8b) return "ABcde8a8b"; + if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b"; + if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b"; + if (v == mkldnn_ABcde8b8a) return "ABcde8b8a"; + if (v == mkldnn_aBCde8b8c) return "aBCde8b8c"; + if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c"; + if (v == mkldnn_aBCde8c8b) return "aBCde8c8b"; + if (v == mkldnn_aBcdef16b) return "aBcdef16b"; + if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c"; + if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b"; + if (v == mkldnn_aBcdef4b) return "aBcdef4b"; + if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b"; + if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c"; + if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c"; + if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b"; + if (v == mkldnn_aBdc16b) return "aBdc16b"; + if (v == mkldnn_aBdc4b) return "aBdc4b"; + if (v == mkldnn_aBdc8b) return "aBdc8b"; + if (v == mkldnn_aBdec16b) return "aBdec16b"; + if (v == mkldnn_aBdec4b) return "aBdec4b"; + if (v == mkldnn_aBdec8b) return "aBdec8b"; + if (v == mkldnn_aBdefc16b) return "aBdefc16b"; + if (v == mkldnn_aBdefc4b) return "aBdefc4b"; + if (v == mkldnn_aBdefc8b) return "aBdefc8b"; + if (v == mkldnn_Acb16a) return "Acb16a"; + if (v == mkldnn_Acb4a) return "Acb4a"; + if (v == mkldnn_Acb8a) return "Acb8a"; + if (v == mkldnn_aCBd16b16c) return "aCBd16b16c"; + if (v == mkldnn_aCBde16b16c) return "aCBde16b16c"; + if (v == mkldnn_Acdb16a) return "Acdb16a"; + if (v == mkldnn_Acdb4a) return "Acdb4a"; + if (v == mkldnn_Acdb8a) return "Acdb8a"; + if (v == mkldnn_Acdeb16a) return "Acdeb16a"; + if (v == mkldnn_Acdeb4a) return "Acdeb4a"; + if (v == mkldnn_Acdeb8a) return "Acdeb8a"; + if (v == mkldnn_BAc16a16b) return "BAc16a16b"; + if (v == mkldnn_BAcd16a16b) return "BAcd16a16b"; + if (v == mkldnn_format_tag_last) return "format_tag_last"; + if (v == mkldnn_x) return "x"; + if (v == mkldnn_nc) return "nc"; + if (v == mkldnn_cn) return "cn"; + if (v == mkldnn_ncw) return "ncw"; + if (v == mkldnn_nwc) return "nwc"; + if (v == mkldnn_nchw) return "nchw"; + if (v == mkldnn_nhwc) return "nhwc"; + if (v == mkldnn_chwn) return "chwn"; + if (v == mkldnn_ncdhw) return "ncdhw"; + if (v == mkldnn_ndhwc) return "ndhwc"; + if (v == mkldnn_oi) return "oi"; + if (v == mkldnn_io) return "io"; + if (v == mkldnn_oiw) return "oiw"; + if (v == mkldnn_wio) return "wio"; + if (v == mkldnn_oihw) return "oihw"; + if (v == mkldnn_hwio) return "hwio"; + if (v == mkldnn_ihwo) return "ihwo"; + if (v == mkldnn_iohw) return "iohw"; + if (v == mkldnn_oidhw) return "oidhw"; + if (v == mkldnn_dhwio) return "dhwio"; + if (v == mkldnn_goiw) return "goiw"; + if (v == mkldnn_goihw) return "goihw"; + if (v == mkldnn_hwigo) return "hwigo"; + if (v == mkldnn_giohw) return "giohw"; + if (v == mkldnn_goidhw) return "goidhw"; + if (v == mkldnn_tnc) return "tnc"; + if (v == mkldnn_ntc) return "ntc"; + if (v == mkldnn_ldsnc) return "ldsnc"; + if (v == mkldnn_ldigo) return "ldigo"; + if (v == mkldnn_ldgoi) return "ldgoi"; + if (v == mkldnn_ldgo) return "ldgo"; + if (v == mkldnn_nCdhw16c) return "nCdhw16c"; + if (v == mkldnn_nCdhw4c) return "nCdhw4c"; + if (v == mkldnn_nCdhw8c) return "nCdhw8c"; + if (v == mkldnn_nChw16c) return "nChw16c"; + if (v == mkldnn_nChw4c) return "nChw4c"; + if (v == mkldnn_nChw8c) return "nChw8c"; + if (v == mkldnn_nCw16c) return "nCw16c"; + if (v == mkldnn_nCw4c) return "nCw4c"; + if (v == mkldnn_nCw8c) return "nCw8c"; + if (v == mkldnn_IOw16o16i) return "IOw16o16i"; + if (v == mkldnn_OIw16i16o) return "OIw16i16o"; + if (v == mkldnn_OIw16o16i) return "OIw16o16i"; + if (v == mkldnn_Oiw16o) return "Oiw16o"; + if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i"; + if (v == mkldnn_OIw4i4o) return "OIw4i4o"; + if (v == mkldnn_Oiw4o) return "Oiw4o"; + if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i"; + if (v == mkldnn_OIw8i8o) return "OIw8i8o"; + if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o"; + if (v == mkldnn_OIw8o8i) return "OIw8o8i"; + if (v == mkldnn_Owi16o) return "Owi16o"; + if (v == mkldnn_Owi4o) return "Owi4o"; + if (v == mkldnn_Owi8o) return "Owi8o"; + if (v == mkldnn_IOhw16o16i) return "IOhw16o16i"; + if (v == mkldnn_Ohwi16o) return "Ohwi16o"; + if (v == mkldnn_Ohwi4o) return "Ohwi4o"; + if (v == mkldnn_Ohwi8o) return "Ohwi8o"; + if (v == mkldnn_OIhw16i16o) return "OIhw16i16o"; + if (v == mkldnn_OIhw16o16i) return "OIhw16o16i"; + if (v == mkldnn_Oihw16o) return "Oihw16o"; + if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i"; + if (v == mkldnn_OIhw4i4o) return "OIhw4i4o"; + if (v == mkldnn_Oihw4o) return "Oihw4o"; + if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i"; + if (v == mkldnn_OIhw8i8o) return "OIhw8i8o"; + if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o"; + if (v == mkldnn_OIhw8o8i) return "OIhw8o8i"; + if (v == mkldnn_Odhwi16o) return "Odhwi16o"; + if (v == mkldnn_Odhwi4o) return "Odhwi4o"; + if (v == mkldnn_Odhwi8o) return "Odhwi8o"; + if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o"; + if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i"; + if (v == mkldnn_Oidhw16o) return "Oidhw16o"; + if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o"; + if (v == mkldnn_Oidhw4o) return "Oidhw4o"; + if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i"; + if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o"; + if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i"; + if (v == mkldnn_Goiw16g) return "Goiw16g"; + if (v == mkldnn_gIOw16o16i) return "gIOw16o16i"; + if (v == mkldnn_gOIw16i16o) return "gOIw16i16o"; + if (v == mkldnn_gOIw16o16i) return "gOIw16o16i"; + if (v == mkldnn_gOiw16o) return "gOiw16o"; + if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i"; + if (v == mkldnn_gOIw4i4o) return "gOIw4i4o"; + if (v == mkldnn_gOiw4o) return "gOiw4o"; + if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i"; + if (v == mkldnn_gOIw8i8o) return "gOIw8i8o"; + if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o"; + if (v == mkldnn_gOIw8o8i) return "gOIw8o8i"; + if (v == mkldnn_gOwi16o) return "gOwi16o"; + if (v == mkldnn_gOwi4o) return "gOwi4o"; + if (v == mkldnn_gOwi8o) return "gOwi8o"; + if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i"; + if (v == mkldnn_gOhwi16o) return "gOhwi16o"; + if (v == mkldnn_gOhwi4o) return "gOhwi4o"; + if (v == mkldnn_gOhwi8o) return "gOhwi8o"; + if (v == mkldnn_Goihw16g) return "Goihw16g"; + if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o"; + if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i"; + if (v == mkldnn_gOihw16o) return "gOihw16o"; + if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i"; + if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i"; + if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o"; + if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i"; + if (v == mkldnn_gOihw4o) return "gOihw4o"; + if (v == mkldnn_Goihw8g) return "Goihw8g"; + if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i"; + if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o"; + if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o"; + if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i"; + if (v == mkldnn_gOdhwi16o) return "gOdhwi16o"; + if (v == mkldnn_gOdhwi4o) return "gOdhwi4o"; + if (v == mkldnn_gOdhwi8o) return "gOdhwi8o"; + if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o"; + if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i"; + if (v == mkldnn_gOidhw16o) return "gOidhw16o"; + if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o"; + if (v == mkldnn_gOidhw4o) return "gOidhw4o"; + if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i"; + if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o"; + if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i"; + assert(!"unknown fmt_tag"); + return "unknown fmt_tag"; +} + +const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) { + if (v == mkldnn_prop_kind_undef) return "undef"; + if (v == mkldnn_forward_training) return "forward_training"; + if (v == mkldnn_forward_inference) return "forward_inference"; + if (v == mkldnn_forward_scoring) return "forward_scoring"; + if (v == mkldnn_forward) return "forward"; + if (v == mkldnn_backward) return "backward"; + if (v == mkldnn_backward_data) return "backward_data"; + if (v == mkldnn_backward_weights) return "backward_weights"; + if (v == mkldnn_backward_bias) return "backward_bias"; + assert(!"unknown prop_kind"); + return "unknown prop_kind"; +} + +const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) { + if (v == mkldnn_undefined_primitive) return "undef"; + if (v == mkldnn_reorder) return "reorder"; + if (v == mkldnn_shuffle) return "shuffle"; + if (v == mkldnn_concat) return "concat"; + if (v == mkldnn_sum) return "sum"; + if (v == mkldnn_convolution) return "convolution"; + if (v == mkldnn_deconvolution) return "deconvolution"; + if (v == mkldnn_eltwise) return "eltwise"; + if (v == mkldnn_softmax) return "softmax"; + if (v == mkldnn_pooling) return "pooling"; + if (v == mkldnn_lrn) return "lrn"; + if (v == mkldnn_batch_normalization) return "batch_normalization"; + if (v == mkldnn_inner_product) return "inner_product"; + if (v == mkldnn_rnn) return "rnn"; + assert(!"unknown prim_kind"); + return "unknown prim_kind"; +} + +const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) { + if (v == mkldnn_alg_kind_undef) return "undef"; + if (v == mkldnn_convolution_direct) return "convolution_direct"; + if (v == mkldnn_convolution_winograd) return "convolution_winograd"; + if (v == mkldnn_convolution_auto) return "convolution_auto"; + if (v == mkldnn_deconvolution_direct) return "deconvolution_direct"; + if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd"; + if (v == mkldnn_eltwise_relu) return "eltwise_relu"; + if (v == mkldnn_eltwise_tanh) return "eltwise_tanh"; + if (v == mkldnn_eltwise_elu) return "eltwise_elu"; + if (v == mkldnn_eltwise_square) return "eltwise_square"; + if (v == mkldnn_eltwise_abs) return "eltwise_abs"; + if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt"; + if (v == mkldnn_eltwise_linear) return "eltwise_linear"; + if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu"; + if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu"; + if (v == mkldnn_eltwise_logistic) return "eltwise_logistic"; + if (v == mkldnn_pooling_max) return "pooling_max"; + if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding"; + if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding"; + if (v == mkldnn_pooling_avg) return "pooling_avg"; + if (v == mkldnn_lrn_across_channels) return "lrn_across_channels"; + if (v == mkldnn_lrn_within_channel) return "lrn_within_channel"; + if (v == mkldnn_vanilla_rnn) return "vanilla_rnn"; + if (v == mkldnn_vanilla_lstm) return "vanilla_lstm"; + if (v == mkldnn_vanilla_gru) return "vanilla_gru"; + if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset"; + assert(!"unknown alg_kind"); + return "unknown alg_kind"; +} + +const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) { + if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right"; + if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left"; + if (v == mkldnn_bidirectional_concat) return "bidirectional_concat"; + if (v == mkldnn_bidirectional_sum) return "bidirectional_sum"; + if (v == mkldnn_unidirectional) return "unidirectional"; + assert(!"unknown rnn_direction"); + return "unknown rnn_direction"; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp new file mode 100644 index 0000000000..7e5789e2c3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_HPP +#define MKLDNN_THREAD_HPP + +#include "utils.hpp" +#include "z_magic.hpp" + +#define MKLDNN_THR_SEQ 0 +#define MKLDNN_THR_OMP 1 +#define MKLDNN_THR_TBB 2 + +/* Ideally this condition below should never happen (if the library is built + * using regular cmake). For the 3rd-party projects that build the library + * from the sources on their own try to guess the right threading... */ +#if !defined(MKLDNN_THR) +# define MKLDNN_THR MKLDNN_THR_TBB +#endif + +#if MKLDNN_THR == MKLDNN_THR_SEQ +#define MKLDNN_THR_SYNC 1 +inline int mkldnn_get_max_threads() { return 1; } +inline int mkldnn_get_num_threads() { return 1; } +inline int mkldnn_get_thread_num() { return 0; } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() {} + +#define PRAGMA_OMP(...) + +#elif MKLDNN_THR == MKLDNN_THR_OMP +#include +#define MKLDNN_THR_SYNC 1 + +inline int mkldnn_get_max_threads() { return omp_get_max_threads(); } +inline int mkldnn_get_num_threads() { return omp_get_num_threads(); } +inline int mkldnn_get_thread_num() { return omp_get_thread_num(); } +inline int mkldnn_in_parallel() { return omp_in_parallel(); } +inline void mkldnn_thr_barrier() { +# pragma omp barrier +} + +#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__)) + +#elif MKLDNN_THR == MKLDNN_THR_TBB +#include "tbb/task_arena.h" +#include "tbb/parallel_for.h" +#define MKLDNN_THR_SYNC 0 + +inline int mkldnn_get_max_threads() +{ return tbb::this_task_arena::max_concurrency(); } +inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); } +inline int mkldnn_get_thread_num() +{ return tbb::this_task_arena::current_thread_index(); } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); } + +#define PRAGMA_OMP(...) + +#endif + +/* MSVC still supports omp 2.0 only */ +#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) +# define collapse(x) +# define PRAGMA_OMP_SIMD(...) +#else +# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__)) +#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER) + +namespace mkldnn { +namespace impl { + +inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; } + +template +inline void balance211(T n, U team, U tid, T &n_start, T &n_end) { + T n_min = 1; + T &n_my = n_end; + if (team <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else if (n_min == 1) { + // team = T1 + T2 + // n = T1*n1 + T2*n2 (n1 - n2 = 1) + T n1 = utils::div_up(n, (T)team); + T n2 = n1 - 1; + T T1 = n - n2 * (T)team; + n_my = (T)tid < T1 ? n1 : n2; + n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; + } + + n_end += n_start; +} + +} // namespace impl +} // namespace mkldnn + +#include "mkldnn_thread_parallel_nd.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp new file mode 100644 index 0000000000..50f9b29622 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP +#define MKLDNN_THREAD_PARALLEL_ND_HPP + +/* This header must be included by mkldnn_thread.hpp only */ + +/* Functions: + * - parallel(nthr, f) - executes f in parallel using at most + * nthr threads. If nthr equals 0 + * mkldnn_get_max_threads() threads is + * used + * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already + * created threads + * - parallel_nd(dims..., f) - creates a parallel section and then + * calls for_nd + * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then + * calls for_nd (mostly for convenience) + */ + +namespace mkldnn { +namespace impl { + +/* general parallelization */ +template +void parallel(int nthr, F f) { + if (nthr == 0) nthr = mkldnn_get_max_threads(); +#if MKLDNN_THR == MKLDNN_THR_SEQ + assert(nthr == 1); + f(0, 1); +#elif MKLDNN_THR == MKLDNN_THR_OMP + if (nthr == 1) { f(0, 1); return; } +# pragma omp parallel num_threads(nthr) + f(mkldnn_get_thread_num(), mkldnn_get_num_threads()); +#elif MKLDNN_THR == MKLDNN_THR_TBB + if (nthr == 1) { f(0, 1); return; } + tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner()); +#endif +} + +/* for_nd section */ + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, F f) { + T0 start{0}, end{0}; + balance211(D0, nthr, ithr, start, end); + for (T0 d0 = start; d0 < end; ++d0) f(d0); +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } +} + +template +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } +} + +// Skip a lambda function in the parameter pack. +template +constexpr size_t get_work_amount(const T &v) { return 1; } +template +constexpr size_t get_work_amount(const T &v, Args &&...args) +{ return (size_t)v * get_work_amount(utils::forward(args)...); } + +/* parallel_nd and parallel_nd_in_omp section */ + +#if MKLDNN_THR != MKLDNN_THR_TBB +template +void parallel_nd(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + const bool do_parallel = get_work_amount(utils::forward(args)...) > 1; +# pragma omp parallel if (do_parallel) + { + const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads(); + const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num(); + for_nd(ithr, nthr, utils::forward(args)...); + } +#endif +} +#else // MKLDNN_THR != MKLDNN_THR_TBB + +// gcc 4.8 has a bug with passing parameter pack to lambdas. +// So have to explicitly instantiate all the cases. + +template +void parallel_nd(const T0 &D0, F f) { + const size_t work_amount = (size_t)D0; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(T0(iwork)); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } + }, tbb::static_partitioner()); +} + +template +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } + }, tbb::static_partitioner()); +} +#endif + +template +void parallel_nd_in_omp(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(), + utils::forward(args)...); +#elif MKLDNN_THR == MKLDNN_THR_TBB + assert(!"unsupported parallel_nd_in_omp()"); +#endif +} + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp new file mode 100644 index 0000000000..aa671a0b6e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp @@ -0,0 +1,77 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_TRAITS_HPP +#define MKLDNN_TRAITS_HPP + +#include +#include + +#include "mkldnn.h" +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +template struct prec_traits {}; /* ::type -> float */ +template struct data_traits {}; /* ::data_type -> f32 */ +template struct typesize_traits {}; /* ::data_type_size -> f32 */ +template struct pkind_traits {}; /* ::desc_type, ::query_d */ + +template <> struct prec_traits { typedef float type; }; +template <> struct prec_traits { typedef int32_t type; }; +template <> struct prec_traits { typedef int8_t type; }; +template <> struct prec_traits { typedef uint8_t type; }; + +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::f32; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::s32; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::s8; }; +template <> struct data_traits +{ static constexpr data_type_t data_type = data_type::u8; }; + +template <> struct typesize_traits<4> { typedef float type; }; +template <> struct typesize_traits<2> { typedef int16_t type; }; +template <> struct typesize_traits<1> { typedef uint8_t type; }; + +#define PKIND_TRAITS_INST(op) \ +template <> struct pkind_traits { \ + typedef CONCAT2(op, _desc_t) desc_type; \ + static constexpr query_t query_d = query::CONCAT2(op, _d); \ +} +PKIND_TRAITS_INST(convolution); +PKIND_TRAITS_INST(deconvolution); +PKIND_TRAITS_INST(shuffle); +PKIND_TRAITS_INST(eltwise); +PKIND_TRAITS_INST(softmax); +PKIND_TRAITS_INST(pooling); +PKIND_TRAITS_INST(lrn); +PKIND_TRAITS_INST(batch_normalization); +PKIND_TRAITS_INST(inner_product); +PKIND_TRAITS_INST(rnn); +#undef PKIND_TRAITS_INST + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp new file mode 100644 index 0000000000..f89ea999e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef NSTL_HPP +#define NSTL_HPP + +#include +#include +#include + +#include +#include + +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +void *malloc(size_t size, int alignment); +void free(void *p); + +struct c_compatible { + enum { default_alignment = 64 }; + static void *operator new(size_t sz) { + return malloc(sz, default_alignment); + } + static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; } + static void *operator new[](size_t sz) { + return malloc(sz, default_alignment); + } + static void operator delete(void *p) { free(p); } + static void operator delete[](void *p) { free(p); } +}; + +namespace nstl { + +template +inline const T abs(const T& a) { + return a >= 0 ? a : -a; +} + +template +inline const T& max(const T& a, const T& b) { + return a > b ? a : b; +} + +template +inline const T& min(const T& a, const T& b) { + return a < b ? a : b; +} + +template void swap(T& t1, T& t2) { + T tmp(t1); + t1 = t2; + t2 = tmp; +} + +// Rationale: MKL-DNN needs numeric limits implementation that does not +// generate dependencies on C++ run-time libraries. + +template struct numeric_limits; + +template<> struct numeric_limits { + static constexpr float lowest() { return -FLT_MAX; } + static constexpr float max() { return FLT_MAX; } +}; + +template<> struct numeric_limits { + static constexpr int lowest() { return INT32_MIN; } + static constexpr int max() { return INT32_MAX; } +}; + +template<> struct numeric_limits { + static constexpr int16_t lowest() { return INT16_MIN; } + static constexpr int16_t max() { return INT16_MAX; } +}; + +template<> struct numeric_limits { + static constexpr int8_t lowest() { return INT8_MIN; } + static constexpr int8_t max() { return INT8_MAX; } +}; + +template<> struct numeric_limits { + static constexpr uint8_t lowest() { return 0; } + static constexpr uint8_t max() { return UINT8_MAX; } +}; + +template struct is_integral +{ static constexpr bool value = false; }; +template<> struct is_integral { static constexpr bool value = true; }; +template<> struct is_integral { static constexpr bool value = true; }; +template<> struct is_integral { static constexpr bool value = true; }; +template<> struct is_integral { static constexpr bool value = true; }; + +template struct is_same +{ static constexpr bool value = false; }; +template struct is_same +{ static constexpr bool value = true; }; + +// Rationale: MKL-DNN needs container implementations that do not generate +// dependencies on C++ run-time libraries. +// +// Implementation philosophy: caller is responsible to check if the operation +// is valid. The only functions that have to return status are those that +// depend on memory allocation or similar operations. +// +// This means that e.g. an operator [] does not have to check for boundaries. +// The caller should have checked the boundaries. If it did not we crash and +// burn: this is a bug in MKL-DNN and throwing an exception would not have been +// recoverable. +// +// On the other hand, insert() or resize() or a similar operation needs to +// return a status because the outcome depends on factors external to the +// caller. The situation is probably also not recoverable also, but MKL-DNN +// needs to be nice and report "out of memory" to the users. + +enum nstl_status_t { + success = 0, + out_of_memory +}; + +template class vector: public c_compatible { +private: + std::vector _impl; +public: + typedef typename std::vector::iterator iterator; + typedef typename std::vector::const_iterator const_iterator; + typedef typename std::vector::size_type size_type; + vector() {} + vector(size_type n): _impl(n) {} + vector(size_type n, const T &value): _impl(n, value) {} + template + vector(input_iterator first, input_iterator last): _impl(first, last) {} + ~vector() {} + size_type size() const { return _impl.size(); } + T& operator[] (size_type i) { return _impl[i]; } + const T& operator[] (size_type i) const { return _impl[i]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template + nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end) + { + _impl.insert(pos, begin, end); + return success; + } + void clear() { _impl.clear(); } + void push_back(const T& t) { _impl.push_back(t); } + void resize(size_type count) { _impl.resize(count); } + void reserve(size_type count) { _impl.reserve(count); } +}; + +template class map: public c_compatible { +private: + std::map _impl; +public: + typedef typename std::map::iterator iterator; + typedef typename std::map::const_iterator const_iterator; + typedef typename std::map::size_type size_type; + map() {} + ~map() {} + size_type size() const { return _impl.size(); } + T& operator[](const Key &k) { return _impl[k]; } + const T& operator[](const Key &k) const { return _impl[k]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template + void clear() { _impl.clear(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp new file mode 100644 index 0000000000..be96e654ff --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp @@ -0,0 +1,114 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t pooling_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l) + && one_of(alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto pd = pooling_desc_t(); + pd.primitive_kind = primitive_kind::pooling; + pd.prop_kind = prop_kind; + pd.alg_kind = alg_kind; + pd.src_desc.ndims = src_desc->ndims; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + pd.diff_src_desc = pd.src_desc = zero_md(); + pd.diff_dst_desc = pd.dst_desc = zero_md(); + + (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc; + (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(pd.strides, strides, sp_dims); + utils::array_copy(pd.kernel, kernel, sp_dims); + utils::array_copy(pd.padding[0], padding_l, sp_dims); + utils::array_copy(pd.padding[1], padding_r, sp_dims); + + pd.padding_kind = padding_kind; + if (one_of(alg_kind, pooling_max, pooling_avg_include_padding, + pooling_avg_exclude_padding)) { + pd.accum_data_type = types::default_accum_data_type( + src_desc->data_type, dst_desc->data_type); + } else { + pd.accum_data_type = dst_desc->data_type; + } + + bool consistency = true + && utils::one_of(src_desc->ndims, 4, 5) + && utils::one_of(dst_desc->ndims, 4, 5) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == dst_desc->dims[1]; + for (int i = 2; i < src_desc->ndims; ++i) + consistency = consistency && ( + (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2] + + padding_r[i - 2]) / strides[i - 2] + 1 + == dst_desc->dims[i]); + if (!consistency) return invalid_arguments; + + *pool_desc = pd; + return success; +} +} + +status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc, + dst_desc, strides, kernel, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_src_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t kernel, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind, + diff_src_desc, diff_dst_desc, strides, kernel, padding_l, + padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp new file mode 100644 index 0000000000..4c9c009412 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef POOLING_PD_HPP +#define POOLING_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct pooling_fwd_pd_t; + +struct pooling_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::pooling; + + pooling_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , ws_md_() + {} + + const pooling_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::pooling_d: + *(const pooling_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common pooling aux functions */ + + dim_t MB() const { return src_desc().dims[0]; } + dim_t C() const { return src_desc().dims[1]; } + + dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; } + dim_t IW() const { return src_desc().dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; } + dim_t OW() const { return dst_desc().dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; } + dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; } + dim_t KW() const { return desc_.kernel[ndims() - 3]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return src_desc().ndims; } + bool is_3d() const { return ndims() == 5; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(src_desc()).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + pooling_desc_t desc_; + const pooling_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t ws_md_; + + void init_default_ws() { + ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md(); + ws_md_.data_type = indices_data_type(); + } + + data_type_t indices_data_type() const { + /* the simplest way to express 256... */ + const int u8_max = nstl::numeric_limits< + typename prec_traits::type>::max(); + return utils::array_product(desc()->kernel, ndims()) <= u8_max + ? data_type::u8 : data_type::s32; + } + +private: + const memory_desc_t &src_desc() const + { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; } + const memory_desc_t &dst_desc() const + { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; } +}; + +struct pooling_fwd_pd_t: public pooling_pd_t { + typedef pooling_fwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_fwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } + +protected: + memory_desc_t src_md_; + memory_desc_t dst_md_; + + virtual status_t set_default_params() { + if (dst_md()->format_kind != format_kind::any) + return status::success; + + if (src_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(dst_md_, + src_md_.format_desc.blocking); + } +}; + +struct pooling_bwd_pd_t: public pooling_pd_t { + typedef pooling_bwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_bwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 1 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t diff_dst_md_; + + virtual status_t set_default_params() { + if (diff_src_md()->format_kind != format_kind::any) + return status::success; + + if (diff_dst_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(diff_src_md_, + diff_dst_md_.format_desc.blocking); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp new file mode 100644 index 0000000000..fdf6522f62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "primitive.hpp" +#include "type_helpers.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::primitive_kind; + +namespace { +// XXX: this is a huge hammer. This disables all and any msan checks on +// primitives outputs. +// +// A proper approach would be an implementation-specific unpoisoning. +void unpoison_outputs(const exec_args_t &args) { + for(const auto &arg: args) { + if (arg.second.is_const) continue; + auto *mem = arg.second.mem; + void *p; + mem->get_data_handle(&p); + size_t s = memory_desc_wrapper(*mem->md()).size(); + msan_unpoison(p, s); + } +} +} + +status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) { + if (primitive_desc) delete primitive_desc; + return success; +} + +status_t mkldnn_primitive_create(primitive_t **primitive, + const primitive_desc_t *primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return primitive_desc->create_primitive(primitive); +} + +status_t mkldnn_primitive_execute(const primitive_t *primitive, + stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) { + bool ok = true + && !utils::any_null(primitive, stream) + && primitive->engine() == stream->engine() + && IMPLICATION(nargs > 0, c_args != nullptr); + if (!ok) return invalid_arguments; + + exec_args_t args; + status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args); + if (status != status::success) return status; + + exec_ctx_t ctx(stream, std::move(args)); + + if (mkldnn_verbose()->level) { + double ms = get_msec(); + status = primitive->execute(ctx); + ms = get_msec() - ms; + printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms); + fflush(0); + } else { + status = primitive->execute(ctx); + } + + if (msan_enabled) unpoison_outputs(ctx.args()); + + return status; +} + +status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive, + const primitive_desc_t **primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return safe_ptr_assign(*primitive_desc, + primitive->pd()); +} + +status_t mkldnn_primitive_destroy(primitive_t *primitive) { + if (primitive != nullptr) + delete primitive; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp new file mode 100644 index 0000000000..3b506d6d1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp @@ -0,0 +1,76 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_HPP +#define PRIMITIVE_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "primitive_exec_types.hpp" + +/** \brief A pure virtual primitive class + * + * Primitive contains links to its inputs & outputs, though it does not track + * their readiness on execution step. + * + * @remark @b Rational. + * Dependencies are essential through-out the whole MKL-DNN library, so it + * makes sense to include them on the very low level. On the other hand, + * tracking them should be a task for corresponding essence, like scheduler, + * stream or whatever. Primitive itself should know nothing about the + * environment it is running in. + * + * @note + * To make user experience better we should provide API which allows + * achieving the best (or good enough) performance when creating primitives + * in natural order: i.e. from bottom to top for forward pass and from top to + * bottom for backward pass. Please consider restriction [1] in Level 0. + */ +struct mkldnn_primitive: public mkldnn::impl::c_compatible { + mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd) + : pd_(pd->clone()) {} + virtual ~mkldnn_primitive() { delete pd_; } + + /** returns primitive's engine */ + mkldnn::impl::engine_t *engine() const { return pd_->engine(); } + /** returns primitive's inputs */ + const mkldnn::impl::primitive_desc_t *pd() const { return pd_; } + /** returns primitive's kind */ + mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); } + + /** executes primitive with execution context @p ctx */ + virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx) + const = 0; + +protected: + const mkldnn::impl::primitive_desc_t *pd_; + +private: + mkldnn_primitive() = delete; + mkldnn_primitive(const mkldnn_primitive &) = delete; + mkldnn_primitive(mkldnn_primitive &&) = delete; + mkldnn_primitive &operator=(const mkldnn_primitive &) = delete; + mkldnn_primitive &operator=(mkldnn_primitive &&) = delete; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp new file mode 100644 index 0000000000..9fd638842c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +namespace mkldnn { +namespace impl { + +status_t scales_t::set(dim_t count, int mask, const float *scales) { + cleanup(); + + count_ = count; + mask_ = mask; + + if (count_ == 1) { + scales_ = scales_buf_; + utils::array_set(scales_, scales[0], scales_buf_size); + } else { + scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64); + if (scales_ == nullptr) + return status::out_of_memory; + + for (dim_t c = 0; c < count_; ++c) + scales_[c] = scales[c]; + } + + return status::success; +} + +} +} + +status_t post_ops_t::append_sum(float scale) { + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::sum; + entry_[len_].sum.scale = scale; + + len_++; + + return success; +} + +status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha, + float beta) { + using namespace mkldnn::impl::alg_kind; + bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic); + if (!known_alg) + return invalid_arguments; + + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::eltwise; + entry_[len_].eltwise.scale = scale; + entry_[len_].eltwise.alg = alg; + entry_[len_].eltwise.alpha = alpha; + entry_[len_].eltwise.beta = beta; + + len_++; + + return success; +} + +status_t primitive_attr_t::set_scratchpad_mode( + scratchpad_mode_t scratchpad_mode) { + using namespace mkldnn::impl::scratchpad_mode; + + const bool ok = one_of(scratchpad_mode, library, user); + if (!ok) + return invalid_arguments; + + scratchpad_mode_ = scratchpad_mode; + return success; +} + +status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) { + this->post_ops_ = post_ops; + return success; +} + +/* Public C API */ + +status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) { + if (attr == nullptr) + return invalid_arguments; + + return safe_ptr_assign(*attr, + new mkldnn_primitive_attr); +} + +status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr, + const primitive_attr_t *existing_attr) { + if (any_null(attr, existing_attr)) + return invalid_arguments; + + return safe_ptr_assign(*attr, + existing_attr->clone()); +} + +status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) { + if (attr) + delete attr; + + return success; +} + +status_t mkldnn_primitive_attr_get_scratchpad_mode( + const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) { + if (any_null(attr, scratchpad_mode)) + return invalid_arguments; + + *scratchpad_mode = attr->scratchpad_mode_; + + return success; +} + +status_t mkldnn_primitive_attr_set_scratchpad_mode( + primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) { + if (any_null(attr)) + return invalid_arguments; + + return attr->set_scratchpad_mode(scratchpad_mode); +} + +status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr, + dim_t *count, int *mask, const float **scales) { + if (any_null(attr, count, mask, scales)) + return invalid_arguments; + + *count = attr->output_scales_.count_; + *mask = attr->output_scales_.mask_; + *scales = attr->output_scales_.scales_; + + return success; +} + +status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr, + dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->output_scales_.set(count, mask, scales); +} + +status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr, + const post_ops_t **post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + *post_ops = &attr->post_ops_; + return success; +} + +status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr, + const post_ops_t *post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + return attr->set_post_ops(*post_ops); +} + +status_t mkldnn_post_ops_create(post_ops_t **post_ops) { + if (post_ops == nullptr) + return invalid_arguments; + + return safe_ptr_assign(*post_ops, new mkldnn_post_ops); +} + +status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) { + if (post_ops) + delete post_ops; + + return success; +} + +int mkldnn_post_ops_len(const post_ops_t *post_ops) { + if (post_ops) + return post_ops->len_; + + return 0; +} + +primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops, + int index) { + bool ok = post_ops && 0 <= index && index < post_ops->len_; + if (!ok) + return primitive_kind::undefined; + + return post_ops->entry_[index].kind; +} + +status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_sum(scale); +} + +namespace { +bool simple_get_params_check(const post_ops_t *post_ops, int index, + primitive_kind_t kind) { + bool ok = true + && post_ops != nullptr + && 0 <= index + && index < post_ops->len_ + && post_ops->entry_[index].kind == kind; + return ok; +} +} + +status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index, + float *scale) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::sum) + && !any_null(scale); + if (!ok) + return invalid_arguments; + + *scale = post_ops->entry_[index].sum.scale; + return success; +} + +status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale, + alg_kind_t kind, float alpha, float beta) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_eltwise(scale, kind, alpha, beta); +} + +status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops, + int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::eltwise) + && !any_null(scale, alpha, beta); + if (!ok) + return invalid_arguments; + + const auto &e = post_ops->entry_[index].eltwise; + *scale = e.scale; + *alg = e.alg; + *alpha = e.alpha; + *beta = e.beta; + + return success; +} + +status_t mkldnn_primitive_attr_set_rnn_data_qparams( + primitive_attr_t *attr, const float scale, const float shift) { + if (attr == nullptr) + return invalid_arguments; + + return attr->rnn_data_qparams_.set(scale, shift); +} + +status_t mkldnn_primitive_attr_set_rnn_weights_qparams( + primitive_attr_t *attr, dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->rnn_weights_qparams_.set(count, mask, scales); +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp new file mode 100644 index 0000000000..e2130c7ab1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_ATTR_HPP +#define PRIMITIVE_ATTR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_data_qparams_t : public c_compatible { + rnn_data_qparams_t() : scale_(1.), shift_(0.) {} + bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); } + + status_t set(float scale, float shift) { + scale_ = scale; + shift_ = shift; + return status::success; + } + + float scale_; + float shift_; +}; + +struct scales_t: public c_compatible { + scales_t(): count_(1), mask_(0), scales_(scales_buf_) + { set(1.); } + + scales_t(const scales_t &rhs): scales_t() + { set(rhs.count_, rhs.mask_, rhs.scales_); } + + ~scales_t() { cleanup(); } + + scales_t &operator=(const scales_t &rhs) { + if (&rhs == this) + return *this; + status_t status = set(rhs.count_, rhs.mask_, rhs.scales_); + assert(status == status::success); + (void)status; + return *this; + } + + bool has_default_values() const { + for (dim_t c = 0; c < count_; ++c) { + if(scales_[c] != 1.) return false; + } + return true; + } + + status_t set(dim_t count, int mask, const float *scales); + status_t set(float single_scale) { return this->set(1, 0, &single_scale); } + + dim_t count_; + int mask_; + float *scales_; + +private: + enum { scales_buf_size = 16 }; + float scales_buf_[scales_buf_size]; + + void cleanup() { + if (scales_ != scales_buf_ && scales_ != nullptr) + impl::free(scales_); + + count_ = 1; + mask_ = 0; + scales_ = scales_buf_; + } +}; + +} +} + +struct mkldnn_post_ops: public mkldnn::impl::c_compatible { + struct entry_t { + struct eltwise_t { + mkldnn::impl::alg_kind_t alg; + float scale, alpha, beta; + }; + + mkldnn::impl::primitive_kind_t kind; + union { + struct { float scale; } sum; + eltwise_t eltwise; + }; + + bool is_eltwise(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::eltwise + && IMPLICATION(require_scale_one, eltwise.scale == 1.f); + } + + bool is_relu(bool require_scale_one = true, + bool require_nslope_zero = true) const { + using namespace mkldnn::impl; + return is_eltwise(require_scale_one) + && eltwise.alg == alg_kind::eltwise_relu + && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f); + } + + bool is_sum(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::sum + && IMPLICATION(require_scale_one, sum.scale == 1.f); + } + }; + + mkldnn_post_ops(): len_(0) {} + + mkldnn::impl::status_t append_sum(float scale); + mkldnn::impl::status_t append_eltwise(float scale, + mkldnn::impl::alg_kind_t alg, float alpha, float beta); + + int find(mkldnn::impl::primitive_kind_t kind, int start = 0, + int stop = -1) const { + if (stop == -1) stop = len_; + stop = mkldnn::impl::nstl::min(stop, len_); + for (int idx = start; idx < stop; ++idx) + if (entry_[idx].kind == kind) return idx; + return -1; + } + + bool has_default_values() const { return len_ == 0; } + + bool contain(mkldnn::impl::primitive_kind_t kind, int index) const + { return find(kind, index, index + 1) == index; } + + enum { capacity = 4 }; + + int len_; + entry_t entry_[capacity]; +}; + +struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible { + mkldnn_primitive_attr() + : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library) + {} + + mkldnn_primitive_attr *clone() const + { return new mkldnn_primitive_attr(*this); } + + /** Returns true if the attributes have default values. + * + * @note The scratchpad_mode_ is not take into account */ + bool has_default_values() const { + return true + && output_scales_.has_default_values() + && post_ops_.has_default_values() + && rnn_data_qparams_.has_default_values() + && rnn_weights_qparams_.has_default_values(); + } + + mkldnn::impl::status_t set_scratchpad_mode( + mkldnn::impl::scratchpad_mode_t scratchpad_mode); + mkldnn::impl::status_t set_post_ops( + const mkldnn::impl::post_ops_t &post_ops); + + mkldnn::impl::scratchpad_mode_t scratchpad_mode_; + mkldnn::impl::scales_t output_scales_; + mkldnn::impl::post_ops_t post_ops_; + mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_; + mkldnn::impl::scales_t rnn_weights_qparams_; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp new file mode 100644 index 0000000000..723c41e05a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t primitive_desc_t::query(query_t what, int idx, void *result) const { + auto safe_ret_md = [&](const memory_desc_t *_) { + if (_ == nullptr) return not_required; + *(const memory_desc_t **)result = _; + return success; + }; + + switch (what) { + case query::engine: *(engine_t**)result = engine(); break; + case query::primitive_kind: *(primitive_kind_t*)result = kind(); break; + + case query::scratchpad_engine: + *(engine_t**)result = scratchpad_engine(); break; + + case query::memory_consumption_s64: + *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break; + + case query::op_d: + if (idx != 0 || op_desc() == nullptr) return invalid_arguments; + *(const_c_op_desc_t *)result + = static_cast(op_desc()); break; + + case query::src_md: return safe_ret_md(src_md(idx)); + case query::diff_src_md: return safe_ret_md(diff_src_md(idx)); + case query::dst_md: return safe_ret_md(dst_md(idx)); + case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx)); + case query::weights_md: return safe_ret_md(weights_md(idx)); + case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx)); + case query::workspace_md: + if (idx != 0) return status::invalid_arguments; + return safe_ret_md(workspace_md(idx)); + case query::scratchpad_md: + if (idx != 0) return status::invalid_arguments; + return safe_ret_md(scratchpad_md(idx)); + + case query::num_of_inputs_s32: *(int*)result = n_inputs(); break; + case query::num_of_outputs_s32: *(int*)result = n_outputs(); break; + + case query::impl_info_str: *(const char **)result = name(); break; + + default: return unimplemented; + } + return success; +} + +status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc, + const primitive_attr_t **attr) { + if (utils::any_null(primitive_desc, attr)) + return invalid_arguments; + + *attr = primitive_desc->attr(); + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp new file mode 100644 index 0000000000..536dcfa1d0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp @@ -0,0 +1,174 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_DESC_HPP +#define PRIMITIVE_DESC_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "primitive_attr.hpp" +#include "verbose.hpp" + +struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { + using md_t = mkldnn::impl::memory_desc_t; + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), kind_(kind) { info_[0] = '\0'; } + + virtual mkldnn_primitive_desc *clone() const = 0; + virtual ~mkldnn_primitive_desc() {} + + const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } + mkldnn::impl::engine_t *engine() const { return engine_; } + mkldnn::impl::primitive_kind_t kind() const { return kind_; } + + virtual void init_info() {} + const char *info() const { return info_; } + + mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() + { return scratchpad_registry_; } + const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const + { return scratchpad_registry_; } + virtual mkldnn::impl::engine_t *scratchpad_engine() const + { return engine_; } + + virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } + + enum class arg_usage_t { unused, input, output }; + virtual arg_usage_t arg_usage( + mkldnn::impl::primitive_arg_index_t arg) const { + using mkldnn::impl::types::is_zero_md; + if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) + return arg_usage_t::output; + return arg_usage_t::unused; + } + +# define DECLARE_MD_STUB(stub) \ + virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ + { return nullptr; } + + DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); + DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); + DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); + DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); + DECLARE_MD_STUB(workspace_md); +# undef DECLARE_MD_STUB + + const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { + return idx == 0 ? &scratchpad_md_ : nullptr; + } + + virtual void init_scratchpad_md() { + auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); + mkldnn::impl::dims_t dims = { size }; + mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, + mkldnn::impl::data_type::u8, mkldnn_x); + } + + /** returns the scratchpad size for the given scratchpad mode. */ + mkldnn::impl::dim_t scratchpad_size( + mkldnn::impl::scratchpad_mode_t mode) const { + if (mode != attr_.scratchpad_mode_) return 0; + return scratchpad_registry().size(); + } + + virtual int n_inputs() const { return 0; } + virtual int n_outputs() const { return 0; } + + virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, + void *result) const; + + virtual mkldnn::impl::status_t create_primitive( + mkldnn::impl::primitive_t **primitive) const = 0; + + virtual const char *name() const { return "mkldnn_primitive_desc"; } + + /* static magic */ + + template + static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, + const mkldnn::impl::op_desc_t *adesc, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_desc_t *hint_fwd) { + using namespace mkldnn::impl; + using namespace mkldnn::impl::status; + using pd_op_desc_t = typename pkind_traits::desc_type; + if (adesc->kind != pd_t::base_pkind) return invalid_arguments; + assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); + auto hint = + reinterpret_cast(hint_fwd); + auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->init_info(); + _pd->init_scratchpad_md(); + *pd = _pd; + return success; + } + +protected: + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_attr_t attr_; + mkldnn::impl::primitive_kind_t kind_; + + mkldnn::impl::memory_desc_t scratchpad_md_; + + char info_[MKLDNN_VERBOSE_BUF_LEN]; + + mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; + +protected: + /** compares ws between fwd_pd and this (make sense to use for bwd_pd) + * Expectation: this already set workspace, and this workspace should + * exactly match the one from fwd_pd */ + bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { + using namespace mkldnn::impl; + if (!workspace_md()) return true; // the impl lives fine w/o workspace + return fwd_pd && fwd_pd->workspace_md() + && *fwd_pd->workspace_md() == *workspace_md(); + } +}; + +#define DECLARE_COMMON_PD_t(impl_name, ...) \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual const char *name() const override { return impl_name; } +#define DECLARE_COMMON_PD_T(impl_name, ...) \ + DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__) + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp new file mode 100644 index 0000000000..43e5a31ef3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "memory.hpp" +#include "primitive.hpp" +#include "primitive_exec_types.hpp" + +namespace mkldnn { +namespace impl { + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args) { + using namespace status; + + if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments; + + int n_inputs = 0; + int n_outputs = 0; + + for (int i = 0; i < nargs; ++i) { + primitive_arg_index_t arg = c_args[i].arg; + auto *mem = c_args[i].memory; + + switch (pd->arg_usage(arg)) { + case primitive_desc_t::arg_usage_t::input: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, true}; + n_inputs++; + break; + case primitive_desc_t::arg_usage_t::output: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, false}; + n_outputs++; + break; + case primitive_desc_t::arg_usage_t::unused: + break; + } + } + + bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md()); + + if (n_inputs != pd->n_inputs()) return invalid_arguments; + if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0)) + return invalid_arguments; + + return success; +} + +const void *exec_ctx_t::input(primitive_arg_index_t arg) const { + if (args_.count(arg) != 1) return nullptr; + const auto ma = args_.at(arg); + assert(ma.is_const); + void *ptr; + status_t status = ma.mem->get_data_handle(&ptr); + assert(status == status::success); MAYBE_UNUSED(status); + return ptr; +} + +void *exec_ctx_t::output(primitive_arg_index_t arg) const { + if (args_.count(arg) != 1) return nullptr; + const auto ma = args_.at(arg); + assert(!ma.is_const); + void *ptr; + status_t status = ma.mem->get_data_handle(&ptr); + assert(status == status::success); MAYBE_UNUSED(status); + return ptr; +} + +const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const { + assert(args_.count(arg) == 1); + const auto ma = args_.at(arg); + assert(!ma.is_const); + return ma.mem; +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp new file mode 100644 index 0000000000..0645891da7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_EXEC_TYPES_HPP +#define PRIMITIVE_EXEC_TYPES_HPP + +#include + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "memory.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct memory_arg_t { + memory_t *mem; + bool is_const; +}; + +using exec_args_t = std::unordered_map; + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args); + +/** Primitive execution context (helps passing stream, memories, and events. */ +struct exec_ctx_t { + exec_ctx_t(const exec_ctx_t &) = default; + exec_ctx_t(exec_ctx_t &&) = default; + + exec_ctx_t(stream_t *stream): stream_(stream) {} + exec_ctx_t(stream_t *stream, exec_args_t &&args) + : stream_(stream) + , args_(std::move(args)) {} + + stream_t *stream() const { return stream_; } + const exec_args_t &args() const { return args_; } + + /* tentative solution... TODO: replace with functions return memory_t */ + const void *input(primitive_arg_index_t arg) const; + void *output(primitive_arg_index_t arg) const; + + const memory_t *memory(primitive_arg_index_t arg) const; + +private: + stream_t *stream_; + exec_args_t args_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp new file mode 100644 index 0000000000..5a1cd7d379 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_iterator_create( + primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc, + const primitive_attr_t *attr, engine_t *engine, + const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd); + if (it == nullptr) return out_of_memory; + + ++(*it); + if (*it == it->end()) { + delete it; + return unimplemented; + } + + *iterator = it; + return success; +} + +status_t mkldnn_primitive_desc_iterator_next( + primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return invalid_arguments; + ++(*iterator); + return *iterator == iterator->end() ? iterator_ends : success; +} + +primitive_desc_t *mkldnn_primitive_desc_iterator_fetch( + const primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return nullptr; + return *(*iterator); +} + +status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc, + const primitive_desc_t *existing_primitive_desc) { + if (utils::any_null(primitive_desc, existing_primitive_desc)) + return invalid_arguments; + return safe_ptr_assign(*primitive_desc, + existing_primitive_desc->clone()); +} + +status_t mkldnn_primitive_desc_iterator_destroy( + primitive_desc_iterator_t *iterator) { + if (iterator != nullptr) + delete iterator; + return success; +} + +status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc, + const_c_op_desc_t c_op_desc, const primitive_attr_t *attr, + engine_t *engine, const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd); + ++it; + if (it == it.end()) return unimplemented; + + return safe_ptr_assign(*primitive_desc, *it); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp new file mode 100644 index 0000000000..4e88ab3aa5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#ifndef PRIMITIVE_ITERATOR_HPP +#define PRIMITIVE_ITERATOR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible { + using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc, + const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd) + : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc) + , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd) + , impl_list_(engine_->get_implementation_list()), last_idx_(0) + { + while (impl_list_[last_idx_] != nullptr) ++last_idx_; + } + ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; } + + bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return idx_ == rhs.idx_ && engine_ == rhs.engine_; } + bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return !operator==(rhs); } + + mkldnn::impl::primitive_desc_iterator_t end() const + { return mkldnn_primitive_desc_iterator(engine_, last_idx_); } + + mkldnn::impl::primitive_desc_iterator_t &operator++() { + if (pd_) { delete pd_; pd_ = nullptr; } + while (++idx_ != last_idx_) { + auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_, + hint_fwd_pd_); + if (s == mkldnn::impl::status::success) break; + } + return *this; + } + + mkldnn::impl::primitive_desc_t *operator*() const { + if (*this == end() || pd_ == nullptr) return nullptr; + return pd_->clone(); + } + +protected: + int idx_; + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_desc_t *pd_; + const mkldnn::impl::op_desc_t *op_desc_; + const mkldnn::impl::primitive_attr_t attr_; + const mkldnn::impl::primitive_desc_t *hint_fwd_pd_; + const pd_create_f *impl_list_; + int last_idx_; + +private: + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx) + : idx_(last_idx), engine_(engine), pd_(nullptr) + , op_desc_(nullptr), hint_fwd_pd_(nullptr) + , impl_list_(nullptr), last_idx_(last_idx) {} +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp new file mode 100644 index 0000000000..835cd73581 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/query.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc, + query_t what, int index, void *result) { + if (any_null(primitive_desc, result)) + return invalid_arguments; + + return primitive_desc->query(what, index, result); +} + +const memory_desc_t *mkldnn_primitive_desc_query_md( + const primitive_desc_t *primitive_desc, query_t what, int index) { + const memory_desc_t *res_md = nullptr; + bool args_ok = true + && primitive_desc != nullptr + && (what & query::some_md) == query::some_md + && what != query::some_md + && mkldnn_primitive_desc_query(primitive_desc, + what, index, &res_md) == success; + return args_ok ? res_md : nullptr; +} + +int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc, + query_t what, int index) { + int res_s32; + bool args_ok = primitive_desc != nullptr + && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32) + && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32) + == success; + return args_ok ? res_s32 : 0; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp new file mode 100644 index 0000000000..d11f1a0361 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "reorder_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_reorder_primitive_desc_create( + primitive_desc_t **reorder_pd, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md, + const primitive_attr_t *attr) { + if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md)) + return invalid_arguments; + + auto s_ek = src_engine->kind(); + auto d_ek = dst_engine->kind(); + if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek))) + return invalid_arguments; + + auto r_pd = reinterpret_cast(reorder_pd); + auto s_mdw = memory_desc_wrapper(*src_md); + auto d_mdw = memory_desc_wrapper(*dst_md); + + if (!s_mdw.consistent_with(d_mdw)) + return invalid_arguments; + + auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + for (auto r = e->get_reorder_implementation_list(); *r; ++r) { + if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md) + == success) { + (*r_pd)->init_info(); + (*r_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp new file mode 100644 index 0000000000..963cb0f58a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp @@ -0,0 +1,85 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REORDER_PD_HPP +#define REORDER_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct reorder_pd_t: public primitive_desc_t { + reorder_pd_t(engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) + : primitive_desc_t(engine, attr, primitive_kind::reorder) + , src_engine_(src_engine) + , dst_engine_(dst_engine) + , scratchpad_engine_(nullptr) + , src_md_(*src_md) + , dst_md_(*dst_md) + {} + + virtual const op_desc_t *op_desc() const override { return nullptr; } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_FROM) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_TO) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + float alpha() const { return attr()->output_scales_.scales_[0]; } + float beta() const { + const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); + return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale; + } + virtual mkldnn::impl::engine_t *scratchpad_engine() const override + { return scratchpad_engine_; } + +protected: + engine_t *src_engine_; + engine_t *dst_engine_; + engine_t *scratchpad_engine_; + + memory_desc_t src_md_; + memory_desc_t dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp new file mode 100644 index 0000000000..36967431a6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu/gemm/os_blas.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::utils; + +namespace { +memory_desc_t copy_maybe_null(const memory_desc_t *md) { + return md ? *md : zero_md(); +} + +rnn_desc_t zero_rnn_desc() { + auto rd = rnn_desc_t(); + rd.src_layer_desc = zero_md(); + rd.src_iter_desc = zero_md(); + rd.weights_layer_desc = zero_md(); + rd.weights_iter_desc = zero_md(); + rd.bias_desc = zero_md(); + rd.dst_layer_desc = zero_md(); + rd.dst_iter_desc = zero_md(); + rd.diff_src_layer_desc = zero_md(); + rd.diff_src_iter_desc = zero_md(); + rd.diff_weights_layer_desc = zero_md(); + rd.diff_weights_iter_desc = zero_md(); + rd.diff_bias_desc = zero_md(); + rd.diff_dst_layer_desc = zero_md(); + rd.diff_dst_iter_desc = zero_md(); + return rd; +} +} + +/* Public C Api */ + +status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, + unsigned int flags, float alpha, float clipping) { + using namespace mkldnn::impl::alg_kind; + + bool args_ok = true + && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru, + gru_linear_before_reset) + && IMPLICATION(cell_kind == vanilla_rnn, + one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic)); + if (!args_ok) + return invalid_arguments; + + auto rcd = mkldnn_rnn_cell_desc_t(); + + rcd.cell_kind = cell_kind; + rcd.activation_kind = act_f; + rcd.flags = flags; + rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; + rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0; + + *rnn_cell_desc = rcd; + + return success; +} + +int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 3; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3; + case mkldnn::impl::alg_kind::vanilla_lstm: return 4; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 1; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1; + case mkldnn::impl::alg_kind::vanilla_lstm: return 2; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc, + prop_kind_t prop_kind, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + using namespace data_type; + data_type_t src_layer_dt = src_layer_desc->data_type; + data_type_t dst_layer_dt = dst_layer_desc->data_type; + data_type_t weights_iter_dt = weights_iter_desc->data_type; + data_type_t weights_layer_dt = weights_layer_desc->data_type; + + bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt, + weights_layer_dt) + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + +#if USE_MKL_PACKED_GEMM + bool is_u8u8u8 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == u8) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == u8) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_f32u8f32 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_inference = prop_kind == prop_kind::forward_inference; + bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm; + + return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference)) + ? success + : unimplemented; +#else + return is_f32 ? success : unimplemented; +#endif +} + +status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc, + rnn_direction_t direction, int L, int D, int T, int N, int S, int G, + int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok; + + // * algorithm specific + args_ok = true + && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru, + DIC == SIC); + if (!args_ok) return invalid_arguments; + int extra_bias = + rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset; + + // * on num layers + args_ok = true + && L == weights_layer_desc->dims[0] + && L == weights_iter_desc->dims[0] + && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0]) + && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]) + && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]); + if (!args_ok) return invalid_arguments; + + // * on num directions + args_ok = true + && D == weights_layer_desc->dims[1] + && D == weights_iter_desc->dims[1] + && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1]) + && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1]) + && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]); + if (!args_ok) return invalid_arguments; + + // * on num iterations + args_ok = true + && T == src_layer_desc->dims[0] + && T == dst_layer_desc->dims[0]; + if (!args_ok) return invalid_arguments; + + // * on mb + args_ok = true + && N == src_layer_desc->dims[1] + && N == dst_layer_desc->dims[1] + && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]); + if (!args_ok) return invalid_arguments; + + // * on num gates + args_ok = true + && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc) + && G == weights_layer_desc->dims[3] + && G == weights_iter_desc->dims[3] + && IMPLICATION(!is_zero_md(bias_desc), + G + extra_bias == bias_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on num states + args_ok = true + && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc) + && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2]) + && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on slc + args_ok = true + && SLC == weights_layer_desc->dims[2] + && SLC == src_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on sic + args_ok = true + && SIC == weights_iter_desc->dims[2] + && IMPLICATION(!is_zero_md(src_iter_desc), + SIC == src_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * on dlc + int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1; + args_ok = true + && DLC == dlc_multiplier * DIC + && DLC == dst_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on dic + args_ok = true + && DIC == weights_layer_desc->dims[4] + && DIC == weights_iter_desc->dims[4] + && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), + DIC == dst_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * unrolling/fusion conditions + args_ok = true + && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC) + && IMPLICATION(T > 1, SIC == DIC); + if (!args_ok) return invalid_arguments; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok = true && rnn_cell_desc != nullptr + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc); + if (!args_ok) return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc)); + + CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind, + src_layer_desc, src_iter_desc, weights_layer_desc, + weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc)); + + // Create the descriptor + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + + *rnn_desc = rd; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, + const memory_desc_t *diff_src_layer_desc, + const memory_desc_t *diff_src_iter_desc, + const memory_desc_t *diff_weights_layer_desc, + const memory_desc_t *diff_weights_iter_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_layer_desc, + const memory_desc_t *diff_dst_iter_desc) { + bool args_ok = true + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc, diff_src_layer_desc, + diff_weights_layer_desc, diff_weights_iter_desc, + diff_dst_layer_desc); + if (!args_ok) + return invalid_arguments; + + auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { + return is_zero_md(a_md) == is_zero_md(b_md); + }; + + args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) + && xnor_md(dst_iter_desc, diff_dst_iter_desc) + && xnor_md(src_iter_desc, diff_src_iter_desc); + if (!args_ok) + return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc); + if (st != success) return st; + + st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc, + diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc, + diff_dst_layer_desc, diff_dst_iter_desc); + if (st != success) return st; + + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); + rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); + rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); + rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); + rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); + rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc); + rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); + + *rnn_desc = rd; + + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp new file mode 100644 index 0000000000..1ee2ba1114 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_PD_HPP +#define RNN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_fwd_pd_t; + +struct rnn_pd_t : public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::rnn; + + rnn_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , src_layer_md_(desc_.src_layer_desc) + , src_iter_md_(desc_.src_iter_desc) + , weights_layer_md_(desc_.weights_layer_desc) + , weights_iter_md_(desc_.weights_iter_desc) + , bias_md_(desc_.bias_desc) + , dst_layer_md_(desc_.dst_layer_desc) + , dst_iter_md_(desc_.dst_iter_desc) + , ws_md_() + {} + + const rnn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::rnn_d: *(const rnn_desc_t **)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &src_layer_md_; + if (index == 1 && with_src_iter()) return &src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_layer_md_; + if (index == 1) return &weights_iter_md_; + if (index == 2 && with_bias()) return &bias_md_; + return nullptr; + } + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &dst_layer_md_; + if (index == 1 && with_dst_iter()) return &dst_iter_md_; + return nullptr; + } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + /* common pooling aux functions */ + + bool is_training() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::backward); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + dim_t T() const { return desc_.src_layer_desc.dims[0]; } + dim_t MB() const { return desc_.src_layer_desc.dims[1]; } + + dim_t L() const { return desc_.weights_layer_desc.dims[0]; } + dim_t D() const { return desc_.weights_layer_desc.dims[1]; } + + dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; } + + dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; } + dim_t G() const { return desc_.weights_layer_desc.dims[3]; } + dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; } + + dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; } + + bool with_bias() const + { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); } + + bool with_src_iter() const + { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); } + + bool with_dst_iter() const + { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); } + + mkldnn::impl::alg_kind_t cell_kind() const + { return desc_.cell_desc.cell_kind; } + mkldnn::impl::alg_kind_t activation_kind() const + { return desc_.cell_desc.activation_kind; } + + bool is_lbr() const + { return cell_kind() == mkldnn_gru_linear_before_reset; } + + mkldnn_rnn_direction_t direction() const { return desc_.direction; } + +protected: + rnn_desc_t desc_; + const rnn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t src_layer_md_; + memory_desc_t src_iter_md_; + memory_desc_t weights_layer_md_; + memory_desc_t weights_iter_md_; + memory_desc_t bias_md_; + memory_desc_t dst_layer_md_; + memory_desc_t dst_iter_md_; + + memory_desc_t ws_md_; +}; + +struct rnn_fwd_pd_t: public rnn_pd_t { + typedef rnn_fwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_fwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC_LAYER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter()) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST_LAYER) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter()) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual int n_inputs() const override + { return 3 + with_bias() + with_src_iter(); } + virtual int n_outputs() const override + { return 1 + with_dst_iter() + is_training(); } +}; + +struct rnn_bwd_pd_t : public rnn_pd_t { + typedef rnn_bwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_bwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_layer_md_(desc_.diff_src_layer_desc) + , diff_src_iter_md_(desc_.diff_src_iter_desc) + , diff_weights_layer_md_(desc_.diff_weights_layer_desc) + , diff_weights_iter_md_(desc_.diff_weights_iter_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_layer_md_(desc_.diff_dst_layer_desc) + , diff_dst_iter_md_(desc_.diff_dst_iter_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER, + MKLDNN_ARG_DIFF_DST_LAYER)) + return arg_usage_t::input; + + if (with_src_iter()) { + if (arg == MKLDNN_ARG_SRC_ITER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC_ITER) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (with_bias()) { + if (arg == MKLDNN_ARG_BIAS) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_BIAS) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER) + && with_dst_iter()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_ITER)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override { + if (index == 0) return &diff_src_layer_md_; + if (index == 1 && with_src_iter()) return &diff_src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *diff_weights_md( + int index = 0) const override { + if (index == 0) return &diff_weights_layer_md_; + if (index == 1) return &diff_weights_iter_md_; + if (index == 2 && with_bias()) return &diff_bias_md_; + return nullptr; + } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override { + if (index == 0) return &diff_dst_layer_md_; + if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_; + return nullptr; + } + + virtual int n_inputs() const override + { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); } + virtual int n_outputs() const override + { return 3 + with_src_iter() + with_bias(); } + +protected: + memory_desc_t diff_src_layer_md_; + memory_desc_t diff_src_iter_md_; + memory_desc_t diff_weights_layer_md_; + memory_desc_t diff_weights_iter_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_layer_md_; + memory_desc_t diff_dst_iter_md_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp new file mode 100644 index 0000000000..6bc14fc72a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "scratchpad.hpp" + +namespace mkldnn { +namespace impl { + +/* Allocating memory buffers on a page boundary to reduce TLB/page misses */ +const size_t page_size = 2097152; + +/* + Implementation of the scratchpad_t interface that is compatible with + a concurrent execution +*/ +struct concurent_scratchpad_t : public scratchpad_t { + concurent_scratchpad_t(size_t size) { + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + + ~concurent_scratchpad_t() { + free(scratchpad_); + } + + virtual char *get() const { + return scratchpad_; + } + +private: + char *scratchpad_; + size_t size_; +}; + +/* + Implementation of the scratchpad_t interface that uses a global + scratchpad +*/ + +struct global_scratchpad_t : public scratchpad_t { + global_scratchpad_t(size_t size) { + if (size > size_) { + if (scratchpad_ != nullptr) free(scratchpad_); + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + reference_count_++; + } + + ~global_scratchpad_t() { + reference_count_--; + if (reference_count_ == 0) { + free(scratchpad_); + scratchpad_ = nullptr; + size_ = 0; + } + } + + virtual char *get() const { + return scratchpad_; + } + +private: + /* + Using thread-local here is unnecessary and even buggy! All threads + actually share the same scratchpad, which is created and queried only + on the main thread. If the scratchpad is queried on some thread other + than the one it was created on (e.g. the application calls the API from + multiple threads), thread-local causes a segfault because the scratchpad + is uninitialized on the current thread. + */ + /*thread_local*/ static char *scratchpad_; + /*thread_local*/ static size_t size_; + /*thread_local*/ static unsigned int reference_count_; +}; + +/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr; +/*thread_local*/ size_t global_scratchpad_t::size_ = 0; +/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0; + + +/* + Scratchpad creation routine +*/ +scratchpad_t *create_scratchpad(size_t size) { +#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC + return new global_scratchpad_t(size); +#else + return new concurent_scratchpad_t(size); +#endif +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp new file mode 100644 index 0000000000..f7a246bc99 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_SCRATCHPAD_HPP +#define COMMON_SCRATCHPAD_HPP + +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct scratchpad_t { + virtual ~scratchpad_t() {} + virtual char *get() const = 0; +}; + +scratchpad_t *create_scratchpad(size_t size); + +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp new file mode 100644 index 0000000000..e32e735224 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, int axis, dim_t group_size) { + bool args_ok = true + && !any_null(shuffle_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward, backward_data) + && axis >= 0 && axis < data_desc->ndims + && group_size > 0 && group_size <= data_desc->dims[axis]; + if (!args_ok) return invalid_arguments; + + auto sd = shuffle_desc_t(); + sd.primitive_kind = primitive_kind::shuffle; + sd.prop_kind = prop_kind; + sd.data_desc = *data_desc; + sd.axis = axis; + sd.group_size = group_size; + + bool consistency = true + && sd.data_desc.dims[axis] % sd.group_size == 0; + if (!consistency) return invalid_arguments; + + *shuffle_desc = sd; + return success; +} +} + +status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis, + dim_t group_size) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return shuffle_desc_init(shuffle_desc, prop_kind, data_desc, axis, + group_size); +} + +status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc, + const memory_desc_t *diff_data_desc, int axis, dim_t group_size) { + return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis, + group_size); +} + +// vim: et ts=5 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp new file mode 100644 index 0000000000..cc5553fe7f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SHUFFLE_PD_HPP +#define SHUFFLE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct shuffle_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::shuffle; + + typedef shuffle_pd_t base_class; + typedef shuffle_pd_t hint_class; + + shuffle_pd_t(engine_t *engine, + const shuffle_desc_t *adesc, + const primitive_attr_t *attr, + const shuffle_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const shuffle_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::shuffle_d: + *(const shuffle_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (is_fwd()) { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + } else { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + } + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + /* shuffle aux functions */ + + dim_t MB() const { return data_md()->dims[0]; } + dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; } + dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_md()->dims[ndims() - 1] : 1; } + + int ndims() const { return data_md()->ndims; } + + int axis() const { return desc_.axis; } + dim_t group_size() const { return desc_.group_size; } + dim_t axis_size() const { return data_md()->dims[axis()]; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + const memory_desc_t *data_md() const { return &data_md_; } + +protected: + shuffle_desc_t desc_; + const shuffle_pd_t *hint_fwd_pd_; + memory_desc_t data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp new file mode 100644 index 0000000000..82848e3d1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) { + bool args_ok = true + && !any_null(softmax_desc, data_desc) + && 0 <= softmax_axis + && softmax_axis < data_desc->ndims; + if (!args_ok) return invalid_arguments; + + auto sd = softmax_desc_t(); + sd.primitive_kind = primitive_kind::softmax; + sd.prop_kind = prop_kind; + + bool is_bwd = (sd.prop_kind == backward_data); + sd.data_desc = *data_desc; + sd.diff_desc = is_bwd ? *diff_desc : zero_md(); + sd.softmax_axis = softmax_axis; + + *softmax_desc = sd; + return success; +} +} + +status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + int softmax_axis) { + if (!one_of(prop_kind, forward_inference, forward_training)) + return invalid_arguments; + return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis); +} + +status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc, + const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, + int softmax_axis) { + return softmax_desc_init(softmax_desc, prop_kind::backward_data, + data_desc, diff_desc, softmax_axis); +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp new file mode 100644 index 0000000000..8a16ce901c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SOFTMAX_PD_HPP +#define SOFTMAX_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct softmax_fwd_pd_t; + +struct softmax_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::softmax; + + softmax_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const softmax_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::softmax_d: + *(const softmax_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common softmax aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + softmax_desc_t desc_; + const softmax_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct softmax_fwd_pd_t: public softmax_pd_t { + typedef softmax_fwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_fwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct softmax_bwd_pd_t: public softmax_pd_t { + typedef softmax_bwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_bwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp new file mode 100644 index 0000000000..00af8935c0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +/* API */ + +status_t mkldnn_stream_create(stream_t **stream, engine_t *engine, + unsigned flags) { + bool args_ok = true + && !utils::any_null(stream, engine) + && flags == stream_flags::default_flags; + if (!args_ok) + return invalid_arguments; + + return safe_ptr_assign(*stream, new stream_t(engine, flags)); +} + +status_t mkldnn_stream_destroy(stream_t *stream) { + delete stream; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp new file mode 100644 index 0000000000..f010e5f6ed --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp @@ -0,0 +1,44 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef STREAM_HPP +#define STREAM_HPP + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" + +struct mkldnn_stream: public mkldnn::impl::c_compatible { + mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags) + : engine_(engine), flags_(flags) {} + virtual ~mkldnn_stream() {} + + /** returns stream's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + + /** returns stream's kind */ + unsigned flags() const { return flags_; } + +protected: + mkldnn::impl::engine_t *engine_; + unsigned flags_; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp new file mode 100644 index 0000000000..365663c0f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "sum_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds, const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != dims[d]) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto s_pd = reinterpret_cast(sum_pd); + + for (auto s = engine->get_sum_implementation_list(); *s; ++s) { + if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) { + (*s_pd)->init_info(); + (*s_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp new file mode 100644 index 0000000000..80254667df --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SUM_PD_HPP +#define SUM_PD_HPP + +#include +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct sum_pd_t: public primitive_desc_t { + sum_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::sum) + , n_(n), dst_md_(*dst_md) + { + scales_.reserve(n_); + for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]); + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + const float *scales() const { return &scales_[0]; } + +protected: + int n_; + nstl::vector scales_; + memory_desc_t dst_md_; + nstl::vector src_mds_; + +protected: + /* inits dst_md_ in simple cases. The call may fail. */ + status_t init() { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(&src_mds_[i]); + if (!src_d.is_blocking_desc() || src_d.is_additional_buffer()) + return status::unimplemented; + } + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain, pick the format of the first input + */ + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (!src_d.is_plain() && src_d.is_blocking_desc()) { + return memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + } + } + + if (src_mds_[0].format_kind != format_kind::blocked) + return status::unimplemented; + + dst_md_ = src_mds_[0]; + + return status::success; + } +}; + +#define DECLARE_SUM_PD_t(impl_name, ...) \ + static status_t create(sum_pd_t **sum_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, const float *scales, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign(*sum_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_SUM_PD_T(impl_name, ...) \ + DECLARE_SUM_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp new file mode 100644 index 0000000000..a408f45980 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TAG_TRAITS_HPP +#define TAG_TRAITS_HPP + +#include + +#include "c_types_map.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +enum class block_dim_t { + _, + _A, _B, + _AB, _BC, +}; + +enum class inner_blk_t { + _, + _4a, _4b, + _8a, _8b, + _16a, _16b, + + _4b4a, _4b4c, _4c4b, + _8a8b, _8b8a, _8b8c, _8c8b, + _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b, + + _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c, +}; + +/** returns the offset within the block for weights blocked over oc and ic */ +template +constexpr int AB_or_BC_blk_off(int x0, int x1) { + using ib = inner_blk_t; + static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b, + ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b, + ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c), + "unexpected inner_blk format"); + return false ? 0 + : (f == ib::_4b4c) ? 4 * x0 + x1 + : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0 + : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1 + : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0 + : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1 + : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0 + : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1 + : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2 + : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4 + : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2 + : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4 + : INT_MIN; +} + +template struct inner_blk_traits { + using ib = inner_blk_t; +}; + +template struct tag_traits { + // block_dim_t block_dims; + // inner_blk_t inner_blks; + // int ndims; +}; + +#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \ +template <> struct tag_traits { \ + static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \ + static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \ + static constexpr int ndims = _ndims; \ +} + +DECL_TRAITS(a, _, _, 1); +DECL_TRAITS(ab, _, _, 2); +DECL_TRAITS(abc, _, _, 3); +DECL_TRAITS(abcd, _, _, 4); +DECL_TRAITS(abcde, _, _, 5); +DECL_TRAITS(abcdef, _, _, 6); +DECL_TRAITS(abdec, _, _, 5); +DECL_TRAITS(acb, _, _, 3); +DECL_TRAITS(acbde, _, _, 5); +DECL_TRAITS(acdb, _, _, 4); +DECL_TRAITS(acdeb, _, _, 5); +DECL_TRAITS(ba, _, _, 2); +DECL_TRAITS(bac, _, _, 3); +DECL_TRAITS(bacd, _, _, 4); +DECL_TRAITS(bcda, _, _, 4); +DECL_TRAITS(cba, _, _, 3); +DECL_TRAITS(cdba, _, _, 4); +DECL_TRAITS(cdeba, _, _, 5); +DECL_TRAITS(decab, _, _, 5); + +DECL_TRAITS(Abc4a, _A, _4a, 3); +DECL_TRAITS(aBc4b, _B, _4b, 3); +DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3); +DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3); +DECL_TRAITS(Abcd4a, _A, _4a, 4); +DECL_TRAITS(aBcd4b, _B, _4b, 4); +DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4); +DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4); +DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4); +DECL_TRAITS(Abcde4a, _A, _4a, 5); +DECL_TRAITS(aBcde4b, _B, _4b, 5); +DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5); +DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5); +DECL_TRAITS(aBcdef4b, _B, _4b, 6); +DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6); +DECL_TRAITS(aBdc4b, _B, _4b, 4); +DECL_TRAITS(aBdec4b, _B, _4b, 5); +DECL_TRAITS(aBdefc4b, _B, _4b, 6); +DECL_TRAITS(Acb4a, _A, _4a, 3); +DECL_TRAITS(Acdb4a, _A, _4a, 4); +DECL_TRAITS(Acdeb4a, _A, _4a, 5); + +DECL_TRAITS(Abc16a, _A, _16a, 3); +DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(aBc16b, _B, _16b, 3); +DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3); +DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3); +DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3); +DECL_TRAITS(aBc8b, _B, _8b, 3); +DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3); +DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3); +DECL_TRAITS(Abcd16a, _A, _16a, 4); +DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4); +DECL_TRAITS(aBcd16b, _B, _16b, 4); +DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4); +DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4); +DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4); +DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4); +DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4); +DECL_TRAITS(aBcd8b, _B, _8b, 4); +DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4); +DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4); +DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4); +DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4); +DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4); +DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4); +DECL_TRAITS(Abcde16a, _A, _16a, 5); +DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5); +DECL_TRAITS(aBcde16b, _B, _16b, 5); +DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5); +DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5); +DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5); +DECL_TRAITS(Abcde8a, _A, _8a, 5); +DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5); +DECL_TRAITS(aBcde8b, _B, _8b, 5); +DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5); +DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5); +DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5); +DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 5); +DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5); +DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5); +DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5); +DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5); +DECL_TRAITS(aBcdef16b, _B, _16b, 6); +DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6); +DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6); +DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6); +DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6); +DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6); +DECL_TRAITS(aBdc16b, _B, _16b, 4); +DECL_TRAITS(aBdc8b, _B, _8b, 4); +DECL_TRAITS(aBdec16b, _B, _16b, 5); +DECL_TRAITS(aBdec8b, _B, _8b, 5); +DECL_TRAITS(aBdefc16b, _B, _16b, 6); +DECL_TRAITS(aBdefc8b, _B, _8b, 6); +DECL_TRAITS(Acb16a, _A, _16a, 3); +DECL_TRAITS(Acb8a, _A, _8a, 3); +DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(Acdb16a, _A, _16a, 4); +DECL_TRAITS(Acdb8a, _A, _8a, 4); +DECL_TRAITS(Acdeb16a, _A, _16a, 5); +DECL_TRAITS(Acdeb8a, _A, _8a, 5); +DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4); + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp new file mode 100644 index 0000000000..4f06368738 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TYPE_HELPERS_HPP +#define TYPE_HELPERS_HPP + +#include +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "mkldnn_traits.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "math_utils.hpp" + +namespace mkldnn { +namespace impl { + +template +status_t safe_ptr_assign(T * &lhs, T* rhs) { + if (rhs == nullptr) return status::out_of_memory; + lhs = rhs; + return status::success; +} + +template struct is_subset +{ static constexpr bool value = false; }; +template struct is_subset +{ static constexpr bool value = true; }; +template struct is_subset::value, float>::type> +{ static constexpr bool value = true; }; +#define ISSPEC(t1, t2) template <> \ + struct is_subset { static constexpr bool value = true; } +ISSPEC(int16_t, int32_t); +ISSPEC(int8_t, int32_t); +ISSPEC(uint8_t, int32_t); +ISSPEC(int8_t, int16_t); +ISSPEC(uint8_t, int16_t); +#undef ISSPEC + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs); + +namespace types { + +inline size_t data_type_size(data_type_t data_type) { + using namespace data_type; + switch (data_type) { + case f32: return sizeof(prec_traits::type); + case s32: return sizeof(prec_traits::type); + case s8: return sizeof(prec_traits::type); + case u8: return sizeof(prec_traits::type); + case data_type::undef: + default: assert(!"unknown data_type"); + } + return 0; /* not supposed to be reachable */ +} + +inline format_kind_t format_tag_to_kind(format_tag_t tag) { + switch (tag) { + case format_tag::undef: return format_kind::undef; + case format_tag::any: return format_kind::any; + case format_tag::last: return format_kind::undef; + default: return format_kind::blocked; + } + + assert(!"unreachable"); + return format_kind::undef; +} + +inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs, + const memory_extra_desc_t &rhs) { + return true + && lhs.flags == rhs.flags + && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8, + lhs.compensation_mask == rhs.compensation_mask) + && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust, + lhs.scale_adjust == rhs.scale_adjust); +} + +inline bool blocking_desc_is_equal(const blocking_desc_t &lhs, + const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) { + using mkldnn::impl::utils::array_cmp; + return true + && lhs.inner_nblks == rhs.inner_nblks + && array_cmp(lhs.strides, rhs.strides, ndims) + && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks) + && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks); +} + +inline bool wino_desc_is_equal(const wino_desc_t &lhs, + const wino_desc_t &rhs) { + return lhs.wino_format == rhs.wino_format + && lhs.alpha == rhs.alpha + && lhs.ic == rhs.ic + && lhs.oc == rhs.oc + && lhs.ic_block == rhs.ic_block + && lhs.oc_block == rhs.oc_block + && lhs.ic2_block == rhs.ic2_block + && lhs.oc2_block == rhs.oc2_block + && lhs.r == rhs.r; +} + +inline bool rnn_packed_desc_is_equal( + const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) { + bool ok = true + && lhs.format == rhs.format + && lhs.n_parts == rhs.n_parts + && lhs.offset_compensation == rhs.offset_compensation + && lhs.size == rhs.size + && lhs.n == rhs.n; + if (!ok) + return false; + + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.parts[i] == rhs.parts[i]; + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i]; + return ok; +} + +inline memory_desc_t zero_md() { + auto zero = memory_desc_t(); + return zero; +} + +inline bool is_zero_md(const memory_desc_t *md) { + return md == nullptr || *md == zero_md(); +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t dst_dt) { + using namespace utils; + using namespace data_type; + + if (one_of(f32, src_dt, dst_dt)) return f32; + if (one_of(s32, src_dt, dst_dt)) return s32; + + if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32; + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) { + using namespace utils; + using namespace data_type; + using namespace prop_kind; + + /* prop_kind doesn't matter */ + if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32; + + if (one_of(prop_kind, forward_training, forward_inference)) { + if ((src_dt == u8 || src_dt == s8) + && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8)) + return s32; + } else if (prop_kind == backward_data) { + if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 && + one_of(dst_dt, s8, u8)) + return s32; + } + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +} + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) { + using namespace mkldnn::impl::utils; + bool base_equal = true + && lhs.ndims == rhs.ndims + && array_cmp(lhs.dims, rhs.dims, lhs.ndims) + && lhs.data_type == rhs.data_type + && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims) + && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims) + && lhs.offset0 == rhs.offset0 + && lhs.format_kind == rhs.format_kind; + if (!base_equal) return false; + if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false; + if (lhs.format_kind == format_kind::blocked) + return types::blocking_desc_is_equal(lhs.format_desc.blocking, + rhs.format_desc.blocking, lhs.ndims); + else if (lhs.format_kind == format_kind::wino) + return types::wino_desc_is_equal(lhs.format_desc.wino_desc, + rhs.format_desc.wino_desc); + else if (lhs.format_kind == format_kind::rnn_packed) + return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc, + rhs.format_desc.rnn_packed_desc); + return true; +} + +inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { + return !operator==(lhs, rhs); +} + +inline status_t memory_desc_init_by_strides(memory_desc_t &md, + const dims_t strides) { + return mkldnn_memory_desc_init_by_strides( + &md, md.ndims, md.dims, md.data_type, strides); +} + +inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + status_t status = mkldnn_memory_desc_init_by_tag( + &md, md.ndims, md.dims, md.data_type, tag); + if (status != status::success || strides == nullptr) + return status; + + /* TODO: add consistency check */ + + for (int d = 0; d < md.ndims; ++d) + md.format_desc.blocking.strides[d] = strides[d]; + + return status::success; +} + +/** inits memory descriptor based on logical dimensions kept in @p md, and the + * blocking structure @p blk. + * + * @note blk.strides represent the order only (from smaller to bigger) + * + * TODO: move md related functions to one single place + */ +inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md, + const blocking_desc_t &blk) { + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + dim_t block_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk]; + block_size *= blk.inner_blks[iblk]; + } + + for (int d = 0; d < md.ndims; ++d) { + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + md.padded_offsets[d] = 0; + } + md.offset0 = 0; + + md.format_kind = format_kind::blocked; + auto &mblk = md.format_desc.blocking; + mblk = blk; + + const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy + utils::array_copy(mblk.strides, blk.strides, ndims); + + int perm[MKLDNN_MAX_NDIMS]; + for (int d = 0; d < ndims; ++d) perm[d] = d; + + utils::simultaneous_sort(mblk.strides, perm, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + dim_t stride = block_size; + for (int _d = ndims - 1; _d >= 0; --_d) { + const int d = perm[_d]; + md.format_desc.blocking.strides[d] = stride; + stride *= md.padded_dims[d] / blocks[d]; + } + + md.extra = utils::zero(); + + return status::success; +} + +/** returns true if memory desc @p md corresponds to the given format tag and + * strides. + * If strides are not passed (or passed as nullptr) the dense structure is + * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns). + * Strides might contain `0` value, indicating the stride must match the one + * that mkldnn_memory_desc_init_by_tag() returns. + * Strides might contain `-1` values, that would be ignored during the + * comparison. For instance, this can be used if a stride along minibatch + * doesn't matter. */ +inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + if (md.format_kind != types::format_tag_to_kind(tag)) + return false; + + memory_desc_t md_gold; + status_t status = mkldnn_memory_desc_init_by_tag( + &md_gold, md.ndims, md.dims, md.data_type, tag); + if (status != status::success) return false; + + if (md.format_kind != format_kind::blocked) + return false; // unimplemented yet + + const auto &blk = md.format_desc.blocking; + const auto &blk_gold = md_gold.format_desc.blocking; + + using utils::array_cmp; + bool same_blocks = true + && blk.inner_nblks == blk_gold.inner_nblks + && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks); + + if (!same_blocks) + return false; + + if (strides == nullptr) + return array_cmp(blk.strides, blk_gold.strides, md.ndims); + + for (int d = 0; d < md.ndims; ++d) { + dim_t stride = strides[d]; + if (stride == -1) continue; + if (stride == 0) stride = blk_gold.strides[d]; + if (blk.strides[d] != stride) return false; + } + + return true; +} + +/** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ +template +format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md, + Tags ...tags) { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(md, tag)) + return tag; + } + return format_tag::undef; +} + +} +} + +#include "memory_desc_wrapper.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp new file mode 100644 index 0000000000..d23f4682dc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp @@ -0,0 +1,135 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#ifdef _WIN32 +#include +#include +#endif +#include +#include +#include + +#include "mkldnn.h" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +int getenv(const char *name, char *buffer, int buffer_size) { + if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0)) + return INT_MIN; + + int result = 0; + int term_zero_idx = 0; + size_t value_length = 0; + +#ifdef _WIN32 + value_length = GetEnvironmentVariable(name, buffer, buffer_size); +#else + const char *value = ::getenv(name); + value_length = value == NULL ? 0 : strlen(value); +#endif + + if (value_length > INT_MAX) + result = INT_MIN; + else { + int int_value_length = (int)value_length; + if (int_value_length >= buffer_size) { + result = -int_value_length; + } else { + term_zero_idx = int_value_length; + result = int_value_length; +#ifndef _WIN32 + strncpy(buffer, value, value_length); +#endif + } + } + + if (buffer != NULL) + buffer[term_zero_idx] = '\0'; + return result; +} + +int getenv_int(const char *name, int default_value) +{ + int value = default_value; + // # of digits in the longest 32-bit signed int + sign + terminating null + const int len = 12; + char value_str[len]; + if (getenv(name, value_str, len) > 0) + value = atoi(value_str); + return value; +} + +FILE *fopen(const char *filename, const char *mode) { +#ifdef _WIN32 + FILE *fp = NULL; + return ::fopen_s(&fp, filename, mode) ? NULL : fp; +#else + return ::fopen(filename, mode); +#endif +} + +void *malloc(size_t size, int alignment) { + void *ptr; + +#ifdef _WIN32 + ptr = _aligned_malloc(size, alignment); + int rc = ptr ? 0 : -1; +#else + int rc = ::posix_memalign(&ptr, alignment, size); +#endif + + return (rc == 0) ? ptr : 0; +} + +void free(void *p) { +#ifdef _WIN32 + _aligned_free(p); +#else + ::free(p); +#endif +} + +// Atomic operations +int32_t fetch_and_add(int32_t *dst, int32_t val) { +#ifdef _WIN32 + return InterlockedExchangeAdd(reinterpret_cast(dst), val); +#else + return __sync_fetch_and_add(dst, val); +#endif +} + +static int jit_dump_flag = 0; +static bool jit_dump_flag_initialized = false; +bool jit_dump_enabled() { + if (!jit_dump_flag_initialized) { + jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP"); + jit_dump_flag_initialized = true; + } + return jit_dump_flag != 0; +} + +} +} + +mkldnn_status_t mkldnn_set_jit_dump(int enabled) { + using namespace mkldnn::impl::status; + mkldnn::impl::jit_dump_flag = enabled; + mkldnn::impl::jit_dump_flag_initialized = true; + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp new file mode 100644 index 0000000000..d5a8ec5139 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp @@ -0,0 +1,370 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef UTILS_HPP +#define UTILS_HPP + +#include +#include +#include +#include +#include + +#if defined(__x86_64__) || defined(_M_X64) +#define MKLDNN_X86_64 +#endif + +#define MSAN_ENABLED 0 +#if defined(__has_feature) +#if __has_feature(memory_sanitizer) +#undef MSAN_ENABLED +#define MSAN_ENABLED 1 +#include +#endif +#endif + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +// Sanity check for 64 bits +static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only"); + +#define CHECK(f) do { \ + status_t status = f; \ + if (status != status::success) \ + return status; \ +} while (0) + +#define IMPLICATION(cause, effect) (!(cause) || !!(effect)) + +namespace utils { + +/* a bunch of std:: analogues to be compliant with any msvs version + * + * Rationale: msvs c++ (and even some c) headers contain special pragma that + * injects msvs-version check into object files in order to abi-mismatches + * during the static linking. This makes sense if e.g. std:: objects are passed + * through between application and library, which is not the case for mkl-dnn + * (since there is no any c++-rt dependent stuff, ideally...). */ + +/* SFINAE helper -- analogue to std::enable_if */ +template struct enable_if {}; +template struct enable_if { typedef T type; }; + +/* analogue std::conditional */ +template struct conditional {}; +template struct conditional +{ typedef T type; }; +template struct conditional +{ typedef F type; }; + +template struct conditional3 {}; +template +struct conditional3 { typedef T type; }; +template +struct conditional3 { typedef FT type; }; +template +struct conditional3 { typedef FF type; }; + +template struct conditional_v {}; +template struct conditional_v +{ static constexpr U value = t; }; +template struct conditional_v +{ static constexpr U value = f; }; + +template struct remove_reference { typedef T type; }; +template struct remove_reference { typedef T type; }; +template struct remove_reference { typedef T type; }; + +template +inline T&& forward(typename utils::remove_reference::type &t) +{ return static_cast(t); } +template +inline T&& forward(typename utils::remove_reference::type &&t) +{ return static_cast(t); } + +template +inline typename remove_reference::type zero() +{ auto zero = typename remove_reference::type(); return zero; } + +template +inline bool everyone_is(T val, P item) { return val == item; } +template +inline bool everyone_is(T val, P item, Args... item_others) { + return val == item && everyone_is(val, item_others...); +} + +template +constexpr bool one_of(T val, P item) { return val == item; } +template +constexpr bool one_of(T val, P item, Args... item_others) { + return val == item || one_of(val, item_others...); +} + +template +inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); } + +template +inline void array_copy(T *dst, const T *src, size_t size) { + for (size_t i = 0; i < size; ++i) dst[i] = src[i]; +} +template +inline bool array_cmp(const T *a1, const T *a2, size_t size) { + for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false; + return true; +} +template +inline void array_set(T *arr, const U& val, size_t size) { + for (size_t i = 0; i < size; ++i) arr[i] = static_cast(val); +} + +namespace product_impl { +template struct int2type{}; + +template +constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; } + +template +inline T product_impl(const T *arr, int2type) { + return arr[0]*product_impl(arr+1, int2type()); } +} + +template +inline T array_product(const T *arr) { + return product_impl::product_impl(arr, product_impl::int2type()); +} + +template +inline R array_product(const T *arr, size_t size) { + R prod = 1; + for (size_t i = 0; i < size; ++i) prod *= arr[i]; + return prod; +} + +/** sorts an array of values using @p comparator. While sorting the array + * of value, the function permutes an array of @p keys accordingly. + * + * @note The arrays of @p keys can be omitted. In this case the function + * sorts the array of @vals only. + */ +template +inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) { + if (size == 0) return; + + for (size_t i = 0; i < size - 1; ++i) { + bool swapped = false; + + for (size_t j = 0; j < size - i - 1; j++) { + if (comparator(vals[j], vals[j + 1]) > 0) { + nstl::swap(vals[j], vals[j + 1]); + if (keys) nstl::swap(keys[j], keys[j + 1]); + swapped = true; + } + } + + if (swapped == false) break; + } +} + +template +inline typename remove_reference::type div_up(const T a, const U b) { + assert(b); + return (a + b - 1) / b; +} + +template +inline typename remove_reference::type rnd_up(const T a, const U b) { + return div_up(a, b) * b; +} + +template +inline typename remove_reference::type rnd_dn(const T a, const U b) { + return (a / b) * b; +} + +template T *align_ptr(T *ptr, uintptr_t alignment) +{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); } + +template +inline U this_block_size(const T offset, const U max, const V block_size) { + assert(offset < max); + // TODO (Roma): can't use nstl::max() due to circular dependency... we + // need to fix this + const T block_boundary = offset + block_size; + if (block_boundary > max) + return max - offset; + else + return block_size; +} + +template +inline T nd_iterator_init(T start) { return start; } +template +inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) { + start = nd_iterator_init(start, utils::forward(tuple)...); + x = start % X; + return start / X; +} + +inline bool nd_iterator_step() { return true; } +template +inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) { + if (nd_iterator_step(utils::forward(tuple)...) ) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X) +{ + U max_jump = end - cur; + U dim_jump = X - x; + if (dim_jump <= max_jump) { + x = 0; + cur += dim_jump; + return true; + } else { + cur += max_jump; + x += max_jump; + return false; + } +} +template +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X, + Args &&... tuple) +{ + if (nd_iterator_jump(cur, end, utils::forward(tuple)...)) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template +inline T pick(size_t i, const T &x0) { return x0; } +template +inline T pick(size_t i, const T &x0, Args &&... args) { + return i == 0 ? x0 : pick(i - 1, utils::forward(args)...); +} + +template +T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference, + const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) { + switch (prop_kind) { + case prop_kind::forward_inference: return val_fwd_inference; + case prop_kind::forward_training: return val_fwd_training; + case prop_kind::backward_data: return val_bwd_d; + case prop_kind::backward_weights: return val_bwd_w; + default: assert(!"unsupported prop_kind"); + } + return T(); +} + +template +T pick_by_prop_kind(prop_kind_t prop_kind, + const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w) +{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); } + +template +struct array_offset_calculator { + template + array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... } + { + _base_ptr = base; + } + template + inline Telem &operator()(Targs... Fargs) + { + return *(_base_ptr + _offset(1, Fargs...)); + } + +private: + template + inline size_t _offset(size_t const dimension, size_t element) + { + return element; + } + + template + inline size_t _offset(size_t const dimension, size_t theta, size_t element) + { + return element + (_dims[dimension] * theta); + } + + template + inline size_t _offset(size_t const dimension, size_t theta, size_t element, + Targs... Fargs) + { + size_t t_prime = element + (_dims[dimension] * theta); + return _offset(dimension + 1, t_prime, Fargs...); + } + + Telem *_base_ptr; + const int _dims[Tdims]; +}; + +} + +int32_t fetch_and_add(int32_t *dst, int32_t val); +inline void yield_thread() {} + +// Reads an environment variable 'name' and stores its string value in the +// 'buffer' of 'buffer_size' bytes on success. +// +// - Returns the length of the environment variable string value (excluding +// the terminating 0) if it is set and its contents (including the terminating +// 0) can be stored in the 'buffer' without truncation. +// +// - Returns negated length of environment variable string value and writes +// "\0" to the buffer (if it is not NULL) if the 'buffer_size' is to small to +// store the value (including the terminating 0) without truncation. +// +// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment +// variable is not set. +// +// - Returns INT_MIN if the 'name' is NULL. +// +// - Returns INT_MIN if the 'buffer_size' is negative. +// +// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than +// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to +// retrieve the length of the environment variable value string. +// +int getenv(const char *name, char *buffer, int buffer_size); +// Reads an integer from the environment +int getenv_int(const char *name, int default_value = 0); +bool jit_dump_enabled(); +FILE *fopen(const char *filename, const char *mode); + +constexpr int msan_enabled = MSAN_ENABLED; +inline void msan_unpoison(void *ptr, size_t size) { +#if MSAN_ENABLED + __msan_unpoison(ptr, size); +#endif +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp new file mode 100644 index 0000000000..89a57772cf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp @@ -0,0 +1,665 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#ifndef _WIN32 +#include +#endif + +#include "mkldnn.h" +#include "mkldnn_version.h" +#include "c_types_map.hpp" +#include "verbose.hpp" +#include "cpu/cpu_isa_traits.hpp" + +#include "batch_normalization_pd.hpp" +#include "pooling_pd.hpp" +#include "concat_pd.hpp" +#include "reorder_pd.hpp" +#include "convolution_pd.hpp" +#include "rnn_pd.hpp" +#include "deconvolution_pd.hpp" +#include "shuffle_pd.hpp" +#include "eltwise_pd.hpp" +#include "softmax_pd.hpp" +#include "inner_product_pd.hpp" +#include "sum_pd.hpp" +#include "lrn_pd.hpp" + +/* MKL-DNN CPU ISA info */ +#define ISA_ANY "No instruction set specific optimizations" +#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)" +#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)" +#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)" +#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512)" +#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions" +#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \ + "AVX512-DL Boost)" +#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions" +#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions" + +namespace mkldnn { +namespace impl { + +static verbose_t verbose; +static bool initialized; +static bool version_printed = false; + +const verbose_t *mkldnn_verbose() { +#if !defined(DISABLE_VERBOSE) + if (!initialized) { + const int len = 2; + char val[len] = {0}; + if (getenv("MKLDNN_VERBOSE", val, len) == 1) + verbose.level = atoi(val); + initialized = true; + } + if (!version_printed && verbose.level > 0) { + printf("mkldnn_verbose,info," + "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n", + mkldnn_version()->major, mkldnn_version()->minor, + mkldnn_version()->patch, mkldnn_version()->hash, + get_isa_info()); + version_printed = true; + } +#else + verbose.level = 0; +#endif + return &verbose; +} + +double get_msec() { +#ifdef _WIN32 + static LARGE_INTEGER frequency; + if (frequency.QuadPart == 0) + QueryPerformanceFrequency(&frequency); + LARGE_INTEGER now; + QueryPerformanceCounter(&now); + return 1e+3 * now.QuadPart / frequency.QuadPart; +#else + struct timeval time; + gettimeofday(&time, NULL); + return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec; +#endif +} + +const char *get_isa_info() { + using namespace mkldnn::impl::cpu; + if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS; + if (mayiuse(avx512_mic)) return AVX512_MIC; + if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI; + if (mayiuse(avx512_core)) return AVX512_CORE; + if (mayiuse(avx512_common)) return AVX512_COMMON; + if (mayiuse(avx2)) return AVX2; + if (mayiuse(avx)) return AVX; + if (mayiuse(sse42)) return SSE42; + return ISA_ANY; +} + +/* init_info section */ +namespace { +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_DAT_LEN 256 +#define MKLDNN_VERBOSE_AUX_LEN 384 +#define MKLDNN_VERBOSE_PRB_LEN 384 + +#define DECL_DAT_AUX_PRB_STRS() \ + int dat_written = 0, aux_written = 0, prb_written = 0; \ + MAYBE_UNUSED((dat_written * aux_written * prb_written)); \ + char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \ + char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; MAYBE_UNUSED(aux_str); \ + char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str) + +#define DFMT "%" PRId64 + +void clear_buf(char *buf, int &written) { + /* TODO: do it better */ + buf[0] = '#'; + buf[1] = '\0'; + written = 1; +} + +#define DPRINT(buf, buf_len, written, ...) do { \ + int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \ + if (l < 0 || written + l > buf_len) { \ + clear_buf(buf, written); \ + } else { \ + written += l; \ + } \ +} while(0) + +// XXX: Outputs strings corresponding to memory formats used for data tensors. +void format_prb_desc_str(char *str, int len, const memory_desc_t *md) { + const auto dims = md->dims; + int written = 0; + if (md->ndims == 1) + DPRINT(str, len, written, + "x" DFMT, dims[0]); + else if (md->ndims == 2) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT, dims[0], dims[1]); + else if (md->ndims == 3) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "iw" DFMT, + dims[0], dims[1], dims[2]); + else if (md->ndims == 4) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3]); + else if (md->ndims == 5) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3], dims[4]); + else + mkldnn_md2dim_str(str, len, md); +} + +void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind, + const char *impl_str, mkldnn_prop_kind_t prop_kind, + const char *data_str, const char *aux_str, const char *prb_str) { + MAYBE_UNUSED(verbose_templ); + int written = 0; + DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s", + mkldnn_prim_kind2str(prim_kind), impl_str, + mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str); +} + +template static void init_info_bnorm(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "flags:%u", s->desc()->flags); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_conv(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->ndims() == 5) { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } else { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_shuffle(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + auto md = s->is_fwd() ? s->src_md() : s->diff_dst_md(); + + if (1) { // data + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "axis:%d group_size:" DFMT, s->axis(), s->group_size()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_eltwise(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_iprod(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_lrn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_mem(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "num:%d", s->n_inputs()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str, + aux_str, prb_str); +} + +template static void init_info_pool(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->is_fwd() ? s->src_md() : s->diff_src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->is_fwd() ? s->dst_md() : s->diff_dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // ws + auto md = s->workspace_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->is_3d()) { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "", + s->MB(), s->C(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } else { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT, + s->MB(), s->C(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_softmax(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template static void init_info_rnn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src layer + auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // src iter + auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_layer + auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_iter + auto md = s->is_fwd() ? s->weights_md(1) : s->diff_weights_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bias + auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst layer + auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst iter + auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + alg_kind_t alg_kind = s->cell_kind(); + rnn_direction_t rnn_dir = s->direction(); + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s_%s", mkldnn_alg_kind2str(alg_kind), + mkldnn_rnn_direction2str(rnn_dir)); + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "l" DFMT "t" DFMT "mb" DFMT + "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT, + s->L(), s->T(), s->MB(), + s->SIC(), s->SLC(), s->DIC(), s->DLC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +#undef DPRINT + +#else // !defined(DISABLE_VERBOSE) + +#define DEFINE_STUB(name) \ + template \ + static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \ + { UNUSED(s); UNUSED(buffer); } + +DEFINE_STUB(bnorm); +DEFINE_STUB(conv); +DEFINE_STUB(eltwise); +DEFINE_STUB(iprod); +DEFINE_STUB(lrn); +DEFINE_STUB(mem); +DEFINE_STUB(pool); +DEFINE_STUB(softmax); +DEFINE_STUB(rnn); +DEFINE_STUB(shuffle); +#undef DEFINE_STUB + +#endif // !defined(DISABLE_VERBOSE) +} + +void init_info(batch_normalization_pd_t *s, char *b) +{ init_info_bnorm(s, b); } +void init_info(concat_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(convolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(deconvolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(eltwise_pd_t *s, char *b) +{ init_info_eltwise(s, b); } +void init_info(inner_product_pd_t *s, char *b) +{ init_info_iprod(s, b); } +void init_info(lrn_pd_t *s, char *b) +{ init_info_lrn(s, b); } +void init_info(pooling_pd_t *s, char *b) +{ init_info_pool(s, b); } +void init_info(reorder_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(rnn_pd_t *s, char *b) +{ init_info_rnn(s, b); } +void init_info(shuffle_pd_t *s, char *b) +{ init_info_shuffle(s, b); } +void init_info(softmax_pd_t *s, char *b) +{ init_info_softmax(s, b); } +void init_info(sum_pd_t *s, char *b) +{ init_info_mem(s, b); } + +} +} + +mkldnn_status_t mkldnn_set_verbose(int level) { + using namespace mkldnn::impl::status; + if (level < 0 || level > 2) return invalid_arguments; + mkldnn::impl::verbose.level = level; + mkldnn::impl::initialized = true; + return success; +} + +const mkldnn_version_t *mkldnn_version() { + static mkldnn_version_t ver = { + MKLDNN_VERSION_MAJOR, + MKLDNN_VERSION_MINOR, + MKLDNN_VERSION_PATCH, + MKLDNN_VERSION_HASH}; + return &ver; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp new file mode 100644 index 0000000000..e3049750cb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef VERBOSE_HPP +#define VERBOSE_HPP + +#include +#include + +#include "mkldnn_debug.h" +#include "c_types_map.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +struct verbose_t { + int level; +}; + +const verbose_t *mkldnn_verbose(); +double get_msec(); +const char *get_isa_info(); + +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_BUF_LEN 1024 +#else +#define MKLDNN_VERBOSE_BUF_LEN 1 +#endif + +void init_info(batch_normalization_pd_t *s, char *buffer); +void init_info(concat_pd_t *s, char *buffer); +void init_info(convolution_pd_t *s, char *buffer); +void init_info(deconvolution_pd_t *s, char *buffer); +void init_info(eltwise_pd_t *s, char *buffer); +void init_info(inner_product_pd_t *s, char *buffer); +void init_info(lrn_pd_t *s, char *buffer); +void init_info(pooling_pd_t *s, char *buffer); +void init_info(reorder_pd_t *s, char *buffer); +void init_info(rnn_pd_t *s, char *buffer); +void init_info(shuffle_pd_t *s, char *buffer); +void init_info(softmax_pd_t *s, char *buffer); +void init_info(sum_pd_t *s, char *buffer); + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp new file mode 100644 index 0000000000..520bd4710b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef Z_MAGIC_HPP +#define Z_MAGIC_HPP + +#define CHAIn2(a,b) a b +#define CHAIN2(a,b) CHAIn2(a,b) + +#define CONCAt2(a,b) a ## b +#define CONCAT2(a,b) CONCAt2(a,b) + +#define STRINGIFy(s) #s +#define STRINGIFY(s) STRINGIFy(s) + +#ifdef _MSC_VER +# define PRAGMA_MACRo(x) __pragma(x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#else +# define PRAGMA_MACRo(x) _Pragma(#x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#endif + +#define UNUSED(x) ((void)x) +#define MAYBE_UNUSED(x) UNUSED(x) + +#if defined(_WIN32) && !defined(__GNUC__) +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp new file mode 100644 index 0000000000..7cf7822d90 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "cpu_barrier.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace simple_barrier { + +void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, + Xbyak::Reg64 reg_nthr) { +# define BAR_CTR_OFF offsetof(ctx_t, ctr) +# define BAR_SENSE_OFF offsetof(ctx_t, sense) + using namespace Xbyak; + + Xbyak::Reg64 reg_tmp = [&]() { + /* returns register which is neither reg_ctx nor reg_nthr */ + Xbyak::Reg64 regs[] = { util::rax, util::rbx, util::rcx }; + for (size_t i = 0; i < sizeof(regs) / sizeof(regs[0]); ++i) + if (!utils::one_of(regs[i], reg_ctx, reg_nthr)) + return regs[i]; + return regs[0]; /* should not happen */ + }(); + + Label barrier_exit_label, barrier_exit_restore_label, spin_label; + + code.cmp(reg_nthr, 1); + code.jbe(barrier_exit_label); + + code.push(reg_tmp); + + /* take and save current sense */ + code.mov(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); + code.push(reg_tmp); + code.mov(reg_tmp, 1); + + if (mayiuse(avx512_mic)) { + code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); + code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); + } + + code.lock(); code.xadd(code.ptr[reg_ctx + BAR_CTR_OFF], reg_tmp); + code.add(reg_tmp, 1); + code.cmp(reg_tmp, reg_nthr); + code.pop(reg_tmp); /* restore previous sense */ + code.jne(spin_label); + + /* the last thread {{{ */ + code.mov(code.qword[reg_ctx + BAR_CTR_OFF], 0); // reset ctx + + // notify waiting threads + code.not_(reg_tmp); + code.mov(code.ptr[reg_ctx + BAR_SENSE_OFF], reg_tmp); + code.jmp(barrier_exit_restore_label); + /* }}} the last thread */ + + code.CodeGenerator::L(spin_label); + code.pause(); + code.cmp(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); + code.je(spin_label); + + code.CodeGenerator::L(barrier_exit_restore_label); + code.pop(reg_tmp); + + code.CodeGenerator::L(barrier_exit_label); +# undef BAR_CTR_OFF +# undef BAR_SENSE_OFF +} + +/** jit barrier generator */ +struct jit_t: public jit_generator { + void (*barrier)(ctx_t *ctx, size_t nthr); + + jit_t() { + generate(*this, abi_param1, abi_param2); + ret(); + barrier = reinterpret_cast(const_cast( + this->getCode())); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_t) +}; + +void barrier(ctx_t *ctx, int nthr) { + static jit_t j; /* XXX: constructed on load ... */ + j.barrier(ctx, nthr); +} + +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp new file mode 100644 index 0000000000..0f55e33aa8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp @@ -0,0 +1,60 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BARRIER_HPP +#define CPU_BARRIER_HPP + +#include + +#include "jit_generator.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace simple_barrier { + +STRUCT_ALIGN(64, +struct ctx_t { + enum { CACHE_LINE_SIZE = 64 }; + volatile size_t ctr; + char pad1[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; + volatile size_t sense; + char pad2[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; +}); + +inline void ctx_init(ctx_t *ctx) { *ctx = utils::zero(); } +void barrier(ctx_t *ctx, int nthr); + +/** injects actual barrier implementation into another jitted code + * @params: + * code -- jit_generator object where the barrier is to be injected + * reg_ctx -- read-only register with pointer to the barrier context + * reg_nnthr -- read-only register with the # of synchronizing threads + */ +void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, + Xbyak::Reg64 reg_nthr); + +} + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp new file mode 100644 index 0000000000..1ed5ad57b9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BATCH_NORMALIZATION_PD_HPP +#define CPU_BATCH_NORMALIZATION_PD_HPP + +#include "batch_normalization_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_batch_normalization_fwd_pd_t: public batch_normalization_fwd_pd_t { + using batch_normalization_fwd_pd_t::batch_normalization_fwd_pd_t; +}; + +struct cpu_batch_normalization_bwd_pd_t: public batch_normalization_bwd_pd_t { + using batch_normalization_bwd_pd_t::batch_normalization_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp new file mode 100644 index 0000000000..b8d5c4fcaf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp @@ -0,0 +1,140 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "cpu_batch_normalization_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace bnorm_utils { + +void cache_balance(size_t working_set_size, dim_t C_blks, + dim_t &C_blks_per_iter, int64_t &iters) { + int nthrs = mkldnn_get_max_threads(); + int l3_size = get_cache_size(3, true) * nthrs / 2; + + C_blks_per_iter = l3_size / working_set_size; + + if (C_blks_per_iter == 0) + C_blks_per_iter = 1; + if (C_blks_per_iter > C_blks) + C_blks_per_iter = C_blks; + + iters = (C_blks + C_blks_per_iter - 1) / C_blks_per_iter; +} + +bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, + int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, + dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, + dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e) { + if (nthr <= C_blks || !mkldnn_thr_syncable()) { + C_ithr = ithr; C_nthr = nthr; + N_ithr = 0; N_nthr = 1; + S_ithr = 0; S_nthr = 1; + N_s = 0; N_e = N; S_s = 0; S_e = SP; + balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); + } else { + if (do_blocking) { + N_nthr = (int)nstl::min(N, nthr); + C_nthr = (int)nstl::min(C_blks, nthr / N_nthr); + S_nthr = (int)nstl::min(SP, nthr / (C_nthr * N_nthr)); + } else { + C_nthr = (int)math::gcd((dim_t)nthr, C_blks); + N_nthr = (int)nstl::min(N, nthr / C_nthr); + S_nthr = (int)nstl::min(SP, nthr / (C_nthr * N_nthr)); + } + + if (!spatial_thr_allowed) + S_nthr = 1; + + if (S_nthr < 1) S_nthr = 1; + if (ithr < C_nthr * N_nthr * S_nthr) { + N_ithr = (ithr / S_nthr) % N_nthr ; + C_ithr = ithr / (N_nthr * S_nthr); + S_ithr = ithr % S_nthr; + balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); + balance211(N, N_nthr, N_ithr, N_s, N_e); + balance211(SP, S_nthr, S_ithr, S_s, S_e); + } else { + S_ithr = N_ithr = C_ithr = -ithr; + S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1; + } + } + + // spatial_thr_allowed is meant to help maintain + // consistent decisions about spatial threading + // between mutiple invocations of this routine. + // It is caller's responsibility to check the + // return value and pass it as a flag to the + // next call if needed. + if (S_nthr == 1) + spatial_thr_allowed = false; + + return spatial_thr_allowed; +} + +bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, + int data_size) { + if (!mkldnn_thr_syncable()) return false; + + dim_t nthr = mkldnn_get_max_threads(); + dim_t SP = bdesc->W() * bdesc->D() * bdesc->H(); + dim_t C_PADDED = memory_desc_wrapper(bdesc->src_md()) + .padded_dims()[1]; + assert(C_PADDED % simd_w == 0); + + size_t data = bdesc->MB() * C_PADDED * SP * data_size; + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0); + dim_t C_blks_per_iter{ 1 }, iters{ 1 }; + dim_t C_blks = C_PADDED / simd_w; + + if (do_blocking) { + int num_tensors = bdesc->is_fwd() ? 1 : 2; + size_t working_set_size + = (bdesc->MB() * SP * simd_w * data_size) * num_tensors; + cache_balance(working_set_size, C_blks, C_blks_per_iter, iters); + } + + // Spatial threading decision made in this function shall be consistent + // with thread_balance() behavior. + C_blks = do_blocking ? C_blks_per_iter : C_blks; + + if (nthr <= C_blks) return false; + + dim_t S_nthr = 1; + if (do_blocking) { + dim_t N_nthr = nstl::min(bdesc->MB(), nthr); + dim_t C_nthr = nstl::min(C_blks, nthr / N_nthr); + S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); + } else { + dim_t C_nthr = math::gcd(nthr, C_blks); + dim_t N_nthr = nstl::min(bdesc->MB(), nthr / C_nthr); + S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); + } + + return S_nthr > 1; +} + +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp new file mode 100644 index 0000000000..0daef0716c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp @@ -0,0 +1,43 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BATCH_NORMALIZATION_UTILS_HPP +#define CPU_BATCH_NORMALIZATION_UTILS_HPP + +#include "batch_normalization_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace bnorm_utils { + +void cache_balance(size_t working_set_size, dim_t C_blks, + dim_t &C_blks_per_iter, int64_t &iters); + +bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, + int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, + dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, + dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e); + +bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, + int data_size); + +} +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp new file mode 100644 index 0000000000..b926491202 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu_engine.hpp" + +/* +#include "cpu/ref_concat.hpp" +#include "cpu/simple_concat.hpp" +*/ + +namespace mkldnn { +namespace impl { +namespace cpu { + +using cpd_create_f = mkldnn::impl::engine_t::concat_primitive_desc_create_f; + +namespace { +#define INSTANCE(...) __VA_ARGS__::pd_t::create +static const cpd_create_f cpu_concat_impl_list[] = { + /* + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(simple_concat_t), + INSTANCE(ref_concat_t), + */ + nullptr, +}; +#undef INSTANCE +} + +const cpd_create_f *cpu_engine_t::get_concat_implementation_list() const { + return cpu_concat_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp new file mode 100644 index 0000000000..0b01bcf163 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_CONCAT_PD_HPP +#define CPU_CONCAT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "concat_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_concat_pd_t: public concat_pd_t { + using concat_pd_t::concat_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp new file mode 100644 index 0000000000..52a38a2294 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_CONVOLUTION_PD_HPP +#define CPU_CONVOLUTION_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t { + using convolution_fwd_pd_t::convolution_fwd_pd_t; + + bool has_padded_dst() const { + memory_desc_wrapper dst_d(&dst_md_); + return OC() != dst_d.padded_dims()[1]; + } + + bool wants_padded_bias() const { + if (!with_bias()) return false; + return has_padded_dst(); + } + + bool wants_zero_pad_dst(bool jit_impl = true) const { + if (!has_padded_dst()) return false; + const auto &po = attr()->post_ops_; + int idx; + if ((idx = po.find(primitive_kind::eltwise)) == -1) return false; + return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg, + jit_impl); + } +}; + +struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t { + using convolution_bwd_data_pd_t::convolution_bwd_data_pd_t; +}; + +struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t { + using convolution_bwd_weights_pd_t::convolution_bwd_weights_pd_t; + + bool wants_padded_bias() const { + if (!with_bias()) return false; + memory_desc_wrapper diff_dst_d(&diff_dst_md_); + return OC() != diff_dst_d.padded_dims()[1]; + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp new file mode 100644 index 0000000000..164c8601d7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_DECONVOLUTION_PD_HPP +#define CPU_DECONVOLUTION_PD_HPP + +#include + +#include "deconvolution_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_deconvolution_fwd_pd_t: public deconvolution_fwd_pd_t { + using deconvolution_fwd_pd_t::deconvolution_fwd_pd_t; +}; + +struct cpu_deconvolution_bwd_data_pd_t: public deconvolution_bwd_data_pd_t { + using deconvolution_bwd_data_pd_t::deconvolution_bwd_data_pd_t; +}; + +struct cpu_deconvolution_bwd_weights_pd_t: public deconvolution_bwd_weights_pd_t { + using deconvolution_bwd_weights_pd_t::deconvolution_bwd_weights_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp new file mode 100644 index 0000000000..c52f00026e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ELTWISE_PD_HPP +#define CPU_ELTWISE_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "eltwise_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_eltwise_fwd_pd_t: public eltwise_fwd_pd_t { + using eltwise_fwd_pd_t::eltwise_fwd_pd_t; +}; + +struct cpu_eltwise_bwd_pd_t: public eltwise_bwd_pd_t { + using eltwise_bwd_pd_t::eltwise_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp new file mode 100644 index 0000000000..ce0a3667ad --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp @@ -0,0 +1,324 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "type_helpers.hpp" +#include "verbose.hpp" + +#include "cpu_engine.hpp" +#include "cpu_memory.hpp" + +//#include "cpu/rnn/ref_rnn.hpp" + +//#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp" +//#include "cpu/jit_avx512_common_1x1_convolution.hpp" +#include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp" +#include "cpu/jit_avx512_common_convolution_winograd.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp" +#include "cpu/jit_avx512_common_convolution.hpp" +//#include "cpu/jit_avx2_1x1_convolution.hpp" +//#include "cpu/jit_sse42_1x1_convolution.hpp" +#include "cpu/jit_avx2_convolution.hpp" +#include "cpu/jit_sse42_convolution.hpp" +//#include "cpu/gemm_convolution.hpp" +//#include "cpu/gemm_x8s8s32x_convolution.hpp" +//#include "cpu/ref_convolution.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp" +//#include "cpu/ref_deconvolution.hpp" +//#include "cpu/ref_shuffle.hpp" +//#include "cpu/jit_uni_eltwise.hpp" +//#include "cpu/ref_eltwise.hpp" +//#include "cpu/ref_softmax.hpp" +#include "cpu/jit_uni_pooling.hpp" +//#include "cpu/jit_uni_i8i8_pooling.hpp" +//#include "cpu/ref_pooling.hpp" +//#include "cpu/nchw_pooling.hpp" +//#include "cpu/nhwc_pooling.hpp" +//#include "cpu/jit_avx512_common_lrn.hpp" +//#include "cpu/jit_uni_lrn.hpp" +//#include "cpu/ref_lrn.hpp" +//#include "cpu/jit_uni_batch_normalization.hpp" +//#include "cpu/ref_batch_normalization.hpp" +//#include "cpu/ncsp_batch_normalization.hpp" +//#include "cpu/nspc_batch_normalization.hpp" +//#include "cpu/ref_inner_product.hpp" +//#include "cpu/gemm_inner_product.hpp" +//#include "cpu/gemm_x8s8s32x_inner_product.hpp" +//#include "cpu/jit_uni_dw_convolution.hpp" +//#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp" +#include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +status_t cpu_engine_t::memory_create(memory_t **memory, + const memory_desc_t *md, void *handle) { + auto _memory = new cpu_memory_t(this, md, handle); + if (_memory == nullptr) + return status::out_of_memory; + + status_t status = _memory->init(); + if (status != status::success) { + delete _memory; + return status; + } + + return safe_ptr_assign(*memory, _memory); +} + +using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + +namespace { +using namespace mkldnn::impl::data_type; + +#define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t> +static const pd_create_f cpu_impl_list[] = { + /* RNN */ + /* + INSTANCE(ref_rnn_fwd_f32_t), + INSTANCE(ref_rnn_fwd_u8s8_t), + INSTANCE(ref_rnn_bwd_f32_t), + */ + /* conv */ + /* + INSTANCE(jit_avx512_common_dw_convolution_fwd_t), + INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t), + INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t), + INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t), + INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t), + INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t), + */ + INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t), + INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t), + //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t), + //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t), + INSTANCE(jit_avx512_common_convolution_winograd_fwd_t), + //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t), + //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t), + INSTANCE(jit_avx512_common_convolution_fwd_t), + //INSTANCE(jit_avx512_common_convolution_bwd_data_t), + //INSTANCE(jit_avx512_common_convolution_bwd_weights_t), + /* + INSTANCE(jit_avx2_dw_convolution_fwd_t), + INSTANCE(jit_avx2_dw_convolution_bwd_data_t), + INSTANCE(jit_avx2_dw_convolution_bwd_weights_t), + INSTANCE(jit_avx2_1x1_convolution_fwd_t), + INSTANCE(jit_avx2_1x1_convolution_bwd_data_t), + INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t), + INSTANCE(jit_sse42_dw_convolution_fwd_t), + INSTANCE(jit_sse42_dw_convolution_bwd_data_t), + INSTANCE(jit_sse42_dw_convolution_bwd_weights_t), + INSTANCE(jit_sse42_1x1_convolution_fwd_t), + */ + INSTANCE(jit_avx2_convolution_fwd_t), + //INSTANCE(jit_avx2_convolution_bwd_data_t), + //INSTANCE(jit_avx2_convolution_bwd_weights_t), + INSTANCE(jit_sse42_convolution_fwd_t), + /* + INSTANCE(gemm_convolution_fwd_t), + INSTANCE(gemm_convolution_bwd_data_t), + INSTANCE(gemm_convolution_bwd_weights_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_weights_t), + */ + /* conv (int) */ + /* + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_fwd_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + INSTANCE(ref_convolution_bwd_data_t), + */ + /* deconv */ + /* + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), + INSTANCE(ref_deconvolution_bwd_weights_t), + INSTANCE(ref_deconvolution_bwd_data_t), + INSTANCE(ref_deconvolution_fwd_t), + */ + /* shuffle */ + /* + INSTANCE(ref_shuffle_t<4>), // f32 or s32 + INSTANCE(ref_shuffle_t<1>), // s8 or u8 + */ + /* eltwise */ + /* + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(jit_uni_eltwise_fwd_t), + INSTANCE(jit_uni_eltwise_bwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_bwd_t), + */ + /* eltwise (int) */ + /* + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_fwd_t), + INSTANCE(ref_eltwise_bwd_t), + */ + /* softmax */ + /* + INSTANCE(ref_softmax_fwd_t), + INSTANCE(ref_softmax_bwd_t), + */ + /* pool */ + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + INSTANCE(jit_uni_pooling_fwd_t), + //INSTANCE(jit_uni_pooling_bwd_t), + /* + INSTANCE(nchw_pooling_fwd_t), + INSTANCE(nchw_pooling_bwd_t), + INSTANCE(nhwc_pooling_fwd_t), + INSTANCE(nhwc_pooling_bwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_bwd_t), + */ + /* pool (int) */ + /* + INSTANCE(jit_uni_i8i8_pooling_fwd_t), + INSTANCE(jit_uni_i8i8_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_fwd_t), + INSTANCE(ref_pooling_bwd_t), + */ + /* lrn */ + /* + INSTANCE(jit_avx512_common_lrn_fwd_t), + INSTANCE(jit_avx512_common_lrn_bwd_t), + INSTANCE(jit_uni_lrn_fwd_t), + INSTANCE(jit_uni_lrn_bwd_t), + INSTANCE(jit_uni_lrn_fwd_t), + INSTANCE(ref_lrn_fwd_t), + INSTANCE(ref_lrn_bwd_t), + */ + /* batch normalization */ + /* + INSTANCE(jit_uni_batch_normalization_fwd_t), + INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(jit_uni_batch_normalization_fwd_t), + INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(jit_uni_batch_normalization_fwd_t), + INSTANCE(jit_uni_batch_normalization_bwd_t), + INSTANCE(ncsp_batch_normalization_fwd_t), + INSTANCE(ncsp_batch_normalization_bwd_t), + INSTANCE(nspc_batch_normalization_fwd_t), + INSTANCE(nspc_batch_normalization_bwd_t), + INSTANCE(ref_batch_normalization_fwd_t), + INSTANCE(ref_batch_normalization_bwd_t), + INSTANCE(ref_batch_normalization_fwd_t), + */ + /* inner product */ + /* + INSTANCE(gemm_inner_product_fwd_t), + INSTANCE(gemm_inner_product_bwd_data_t), + INSTANCE(gemm_inner_product_bwd_weights_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_bwd_data_t), + INSTANCE(ref_inner_product_bwd_weights_t), + */ + /* inner product (int) */ + /* + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + INSTANCE(ref_inner_product_fwd_t), + */ + /* eol */ + nullptr, +}; +#undef INSTANCE +} + +const pd_create_f* cpu_engine_t::get_implementation_list() const { + return cpu_impl_list; +} + +cpu_engine_factory_t engine_factory; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp new file mode 100644 index 0000000000..e4c877ee05 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp @@ -0,0 +1,70 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ENGINE_HPP +#define CPU_ENGINE_HPP + +#include + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "../common/engine.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +class cpu_engine_t: public engine_t { +public: + cpu_engine_t(): engine_t(engine_kind::cpu) {} + + /* implementation part */ + + virtual status_t memory_create(memory_t **memory, + const memory_desc_t *md, void *handle) override; + + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const override; + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const override; + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const override; + virtual const primitive_desc_create_f* + get_implementation_list() const override; +}; + +class cpu_engine_factory_t: public engine_factory_t { +public: + virtual size_t count() const override { return 1; } + virtual engine_kind_t kind() const override { return engine_kind::cpu; } + virtual status_t engine_create(engine_t **engine, + size_t index) const override { + assert(index == 0); + *engine = new cpu_engine_t(); + return status::success; + }; +}; + +extern cpu_engine_factory_t engine_factory; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp new file mode 100644 index 0000000000..5880d3450c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_INNER_PRODUCT_PD_HPP +#define CPU_INNER_PRODUCT_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "inner_product_pd.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { +inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { + using namespace utils; + + auto strides_compatible = [&]() { + bool ok = true; + auto w_str = wei_d.blocking_desc().strides; + auto d_str = src_d.blocking_desc().strides; + for (int i = 1; i < src_d.ndims() - 1; i++) { + ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1]; + } + return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]); + }; + return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc() + && src_d.ndims() == wei_d.ndims() + && src_d.blocking_desc().inner_nblks + == wei_d.blocking_desc().inner_nblks + && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1) + && array_cmp(src_d.blocking_desc().inner_blks, + wei_d.blocking_desc().inner_blks, + wei_d.blocking_desc().inner_nblks) + && array_cmp(src_d.blocking_desc().inner_idxs, + wei_d.blocking_desc().inner_idxs, + wei_d.blocking_desc().inner_nblks) + && strides_compatible() + && dst_d.matches_tag(format_tag::nc) + && src_d.only_padded_dim(1) + && wei_d.only_padded_dim(1) + && src_d.padded_dims()[1] == wei_d.padded_dims()[1] + && src_d.is_dense(true) + && dst_d.is_dense() + && wei_d.is_dense(true); +} +} + +struct cpu_inner_product_fwd_pd_t: public inner_product_fwd_pd_t { + using inner_product_fwd_pd_t::inner_product_fwd_pd_t; +}; + +struct cpu_inner_product_bwd_data_pd_t: public inner_product_bwd_data_pd_t { + using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t; +}; + +struct cpu_inner_product_bwd_weights_pd_t: public inner_product_bwd_weights_pd_t { + using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp new file mode 100644 index 0000000000..da6e9dac8e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp @@ -0,0 +1,151 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ISA_TRAITS_HPP +#define CPU_ISA_TRAITS_HPP + +#include + +#define XBYAK64 +#define XBYAK_NO_OP_NAMES +/* in order to make selinux happy memory that would be marked with X-bit should + * be obtained with mmap */ +#define XBYAK_USE_MMAP_ALLOCATOR +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +/* turn off `size_t to other-type implicit casting` warning + * currently we have a lot of jit-generated instructions that + * take uint32_t, but we pass size_t (e.g. due to using sizeof). + * FIXME: replace size_t parameters with the appropriate ones */ +#pragma warning (disable: 4267) +#endif +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef enum { + isa_any, + sse41, + sse42, + avx, + avx2, + avx512_common, + avx512_core, + avx512_core_vnni, + avx512_mic, + avx512_mic_4ops, +} cpu_isa_t; + +template struct cpu_isa_traits {}; /* ::vlen -> 32 (for avx2) */ + +template <> struct cpu_isa_traits { + typedef Xbyak::Xmm Vmm; + static constexpr int vlen_shift = 4; + static constexpr int vlen = 16; + static constexpr int n_vregs = 16; +}; +template <> struct cpu_isa_traits { + typedef Xbyak::Ymm Vmm; + static constexpr int vlen_shift = 5; + static constexpr int vlen = 32; + static constexpr int n_vregs = 16; +}; +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits { + typedef Xbyak::Zmm Vmm; + static constexpr int vlen_shift = 6; + static constexpr int vlen = 64; + static constexpr int n_vregs = 32; +}; +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +template <> struct cpu_isa_traits: + public cpu_isa_traits {}; + +namespace { + +static Xbyak::util::Cpu cpu; +static inline bool mayiuse(const cpu_isa_t cpu_isa) { + using namespace Xbyak::util; + + switch (cpu_isa) { + case sse41: + case sse42: + // FIXME: SSE4.2 is actually NOT required + //return cpu.has(Cpu::tSSE42); + return cpu.has(Cpu::tSSE41); + case avx: + return cpu.has(Cpu::tAVX); + case avx2: + return cpu.has(Cpu::tAVX2); + case avx512_common: + return cpu.has(Cpu::tAVX512F); + case avx512_core: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512BW) + && cpu.has(Cpu::tAVX512VL) + && cpu.has(Cpu::tAVX512DQ); + case avx512_core_vnni: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512BW) + && cpu.has(Cpu::tAVX512VL) + && cpu.has(Cpu::tAVX512DQ) + && cpu.has(Cpu::tAVX512_VNNI); + case avx512_mic: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512CD) + && cpu.has(Cpu::tAVX512ER) + && cpu.has(Cpu::tAVX512PF); + case avx512_mic_4ops: + return true + && mayiuse(avx512_mic) + && cpu.has(Cpu::tAVX512_4FMAPS) + && cpu.has(Cpu::tAVX512_4VNNIW); + case isa_any: + return true; + } + return false; +} +} + +/* whatever is required to generate string literals... */ +#include "z_magic.hpp" +#define JIT_IMPL_NAME_HELPER(prefix, isa, suffix_if_any) \ + (isa == sse42 ? prefix STRINGIFY(sse42) : \ + (isa == avx ? prefix STRINGIFY(avx) : \ + (isa == avx2 ? prefix STRINGIFY(avx2) : \ + (isa == avx512_common ? prefix STRINGIFY(avx512_common) : \ + (isa == avx512_core ? prefix STRINGIFY(avx512_core) : \ + (isa == avx512_mic ? prefix STRINGIFY(avx512_mic) : \ + (isa == avx512_mic_4ops ? prefix STRINGIFY(avx512_mic_4ops) : \ + prefix suffix_if_any))))))) + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp new file mode 100644 index 0000000000..49988f4c2d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp @@ -0,0 +1,42 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_LRN_PD_HPP +#define CPU_LRN_PD_HPP + +#include + +#include "lrn_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_lrn_fwd_pd_t: public lrn_fwd_pd_t { + using lrn_fwd_pd_t::lrn_fwd_pd_t; +}; + +struct cpu_lrn_bwd_pd_t: public lrn_bwd_pd_t { + using lrn_bwd_pd_t::lrn_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp new file mode 100644 index 0000000000..3c0624cf46 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn_traits.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl; +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::format_tag; + +enum blk_kind_t { a, b, c, ab, ba, bc, cb }; + +template +void typed_zero_pad_blk( + const memory_desc_wrapper &m_d, typename prec_traits
::type *data) { + using data_t = typename prec_traits
::type; + const auto &dims = m_d.dims(); + const auto &pdims = m_d.padded_dims(); + const auto &blk = m_d.blocking_desc(); + auto dim_is_blocked = [&](int dim) { + for (int i = 0; i < blk.inner_nblks; i++) + if (blk.inner_idxs[i] == dim) + return true; + return false; + }; + bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1), + C_blocked = dim_is_blocked(2); + + assert(blk.inner_nblks < 4); + assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked) + || (C_blocked && B_blocked)); + + const int a_tail_s = A_blocked ? dims[0] % blksize : 0; + const int b_tail_s = B_blocked ? dims[1] % blksize : 0; + const int c_tail_s = C_blocked ? dims[2] % blksize : 0; + assert(a_tail_s || b_tail_s || c_tail_s); + + const int A = A_blocked ? pdims[0] / blksize : dims[0]; + const int B = B_blocked ? pdims[1] / blksize : dims[1]; + const int C = C_blocked ? pdims[2] / blksize : dims[2]; + const int D = m_d.ndims() > 3 ? dims[3] : 1; + const int E = m_d.ndims() > 4 ? dims[4] : 1; + const int F = m_d.ndims() > 5 ? dims[5] : 1; + const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1; + + auto zeroize_tail = [&](data_t *d, const int tail_s) { + for (int b = tail_s; b < blksize; ++b) + d[b] = 0; + }; + auto zeroize_tail_inner = [&](data_t *d, const int tail_s) { + for (int b1 = 0; b1 < blksize; ++b1) + for (int b2 = tail_s; b2 < blksize; ++b2) + d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 + + b1 % inner_blk] + = 0; + }; + auto zeroize_tail_outer = [&](data_t *d, const int tail_s) { + for (int b1 = tail_s; b1 < blksize; ++b1) + for (int b2 = 0; b2 < blksize; ++b2) + d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 + + b1 % inner_blk] + = 0; + }; + + if (c_tail_s) { + parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) { + auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)]; + if (blk_kind == c) + zeroize_tail(x, c_tail_s); + else if (blk_kind == bc) + zeroize_tail_inner(x, c_tail_s); + else if (blk_kind == cb) + zeroize_tail_outer(x, c_tail_s); + }); + } + + if (b_tail_s) { + parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) { + auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)]; + if (blk_kind == b) + zeroize_tail(x, b_tail_s); + else if (blk_kind == ab || blk_kind == cb) + zeroize_tail_inner(x, b_tail_s); + else if (blk_kind == ba || blk_kind == bc) + zeroize_tail_outer(x, b_tail_s); + }); + } + + if (a_tail_s) { + parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) { + auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)]; + if (blk_kind == a) + zeroize_tail(x, a_tail_s); + else if (blk_kind == ba) + zeroize_tail_inner(x, a_tail_s); + else if (blk_kind == ab) + zeroize_tail_outer(x, a_tail_s); + }); + } +} + +/* + * all + */ +template +void typed_zero_pad_generic_blocked( + const memory_desc_wrapper &m_d, typename prec_traits
::type *data) { + const int ndims = m_d.ndims(); + const auto &dims = m_d.dims(); + const auto &pdims = m_d.padded_dims(); + + const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true); + + /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1] + * | \ / + * | --------------------- + * has contiguous + * padding + * + * step <-- D_k+1 * ... * D_ndims-1 + * step_dim <-- k + */ + + ptrdiff_t step = 1; + int step_dim = ndims - 1; + for (; step_dim >= 0; --step_dim) { + if (dims[step_dim] != pdims[step_dim]) + break; + step *= dims[step_dim]; + } + + assert(step_dim >= 0 && "no zero padding is required"); + if (step_dim < 0) + return; + + parallel_nd(nelems / step, [&](ptrdiff_t e1) { + bool need_zero = false; + + ptrdiff_t idx = e1; + for (int d = step_dim; d >= 0; --d) { + if (idx % pdims[d] >= dims[d]) { + need_zero = true; + break; + } + idx /= pdims[d]; + } + + if (need_zero) { + for (ptrdiff_t e0 = 0; e0 < step; ++e0) + data[m_d.off_l(e1 * step + e0, true)] = 0; + } + }); +} + +template +status_t cpu_memory_t::typed_zero_pad() const { + const memory_desc_wrapper mdw(md()); + + if (mdw.format_kind() != format_kind::blocked) + return unimplemented; + + if (mdw.nelems(false) == mdw.nelems(true)) + return success; + + auto *data = (typename prec_traits
::type *)data_; + auto blk = mdw.blocking_desc(); + + auto get_blksize = [&](int ind) { + int blksize = 1; + for (int i = 0; i < blk.inner_nblks; i++) { + if (blk.inner_idxs[i] == ind) + blksize *= blk.inner_blks[i]; + } + return blksize; + }; + const int blksize = get_blksize(blk.inner_idxs[0]); + +# define CASE(blksize_, blk_kind) \ + do { \ + if (blksize == blksize_) { \ + typed_zero_pad_blk(mdw, data); \ + return success; \ + } \ + } while(0) + + switch (blk.inner_nblks) { + case 1: + if (blk.inner_idxs[0] == 0) { + CASE(4, a); + CASE(8, a); + CASE(16, a); + } else if (blk.inner_idxs[0] == 1) { + CASE(4, b); + CASE(8, b); + CASE(16, b); + } + break; + case 2: + case 3: + if (!IMPLICATION(blk.inner_nblks == 3, + blk.inner_idxs[0] == blk.inner_idxs[2])) + break; + + if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) { + CASE(4, ab); + CASE(8, ab); + CASE(16, ab); + } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) { + CASE(4, ba); + CASE(8, ba); + CASE(16, ba); + } + if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) { + CASE(4, bc); + CASE(8, bc); + CASE(16, bc); + } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) { + CASE(4, cb); + CASE(8, cb); + CASE(16, cb); + } + break; + default: break; + } + +# undef CASE + + // the last line of defence + typed_zero_pad_generic_blocked
(mdw, data); + return success; +} + +status_t cpu_memory_t::zero_pad() const { + memory_desc_wrapper mdw(md()); + const bool skip_zeroing = false + || data_ == nullptr + || mdw.is_zero() + || !mdw.is_blocking_desc(); + if (skip_zeroing) return success; + + switch (mdw.data_type()) { + case f32: return typed_zero_pad(); + case s32: return typed_zero_pad(); + case s8: return typed_zero_pad(); + case u8: return typed_zero_pad(); + default: assert(!"memory is undefined"); return unimplemented; + } + return unimplemented; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp new file mode 100644 index 0000000000..2c01bcc6af --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_MEMORY_HPP +#define CPU_MEMORY_HPP + +#include + +#include "c_types_map.hpp" +#include "memory.hpp" +#include "memory_desc_wrapper.hpp" + +#include "cpu_engine.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_memory_t: public memory_t { + cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md, void *handle) + : memory_t(engine, md) + , own_data_(handle == MKLDNN_NATIVE_HANDLE_ALLOCATE) + , data_((char *)handle) {} + + cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md) + : cpu_memory_t(engine, md, nullptr) {} + + ~cpu_memory_t() { if (own_data_) free(data_); } + + virtual status_t init() override { + if (own_data_) { + data_ = nullptr; + const size_t size = memory_desc_wrapper(this->md()).size(); + if (size) { + data_ = (char *)malloc(size, 64); + if (data_ == nullptr) + return status::out_of_memory; + } + } + return zero_pad(); + } + + cpu_engine_t *engine() const { return (cpu_engine_t *)memory_t::engine(); } + + virtual status_t get_data_handle(void **handle) const override { + *handle = static_cast(data_); + return status::success; + } + + virtual mkldnn::impl::status_t set_data_handle(void *handle) override { + if (own_data_) { free(data_); own_data_ = false; } + data_ = static_cast(handle); + return zero_pad(); + } + + virtual mkldnn::impl::status_t zero_pad() const override; + +private: + bool own_data_; + char *data_; + + template + mkldnn::impl::status_t typed_zero_pad() const; + + cpu_memory_t(const cpu_memory_t &) = delete; + cpu_memory_t &operator=(const cpu_memory_t &) = delete; + cpu_memory_t &operator=(cpu_memory_t &&) = delete; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp new file mode 100644 index 0000000000..ac2daa415e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_POOLING_PD_HPP +#define CPU_POOLING_PD_HPP + +#include "pooling_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t { + using pooling_fwd_pd_t::pooling_fwd_pd_t; +}; + +struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t { + using pooling_bwd_pd_t::pooling_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp new file mode 100644 index 0000000000..56127f36c2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp @@ -0,0 +1,83 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_PRIMITIVE_HPP +#define CPU_PRIMITIVE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "primitive.hpp" +#include "scratchpad.hpp" + +#define CTX_IN_MEM(type, arg) static_cast(ctx.input(arg)) +#define CTX_OUT_MEM(type, arg) static_cast(ctx.output(arg)) + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_memory_t; + +struct cpu_primitive_t: public primitive_t { + cpu_primitive_t(const primitive_desc_t *pd, + bool use_global_scratchpad = false) + : primitive_t(pd) + , scratchpad_buffer_(nullptr) + , global_scratchpad_(nullptr) + { + const size_t scratchpad_size = + this->pd()->scratchpad_size(scratchpad_mode::library); + + if (scratchpad_size) { + if (use_global_scratchpad) + global_scratchpad_ = create_scratchpad(scratchpad_size); + else + scratchpad_buffer_ = malloc(scratchpad_size, 64); + } + } + + virtual ~cpu_primitive_t() { + delete global_scratchpad_; + free(scratchpad_buffer_); + } + +protected: + memory_tracking::grantor_t scratchpad(const exec_ctx_t &ctx) const { + void *ptr = nullptr; + if (pd()->attr()->scratchpad_mode_ == scratchpad_mode::user) { + ptr = CTX_OUT_MEM(void *, MKLDNN_ARG_SCRATCHPAD); + } else { + ptr = global_scratchpad_ + ? global_scratchpad_->get() : scratchpad_buffer_; + } + + return pd()->scratchpad_registry().grantor(ptr); + } + +private: + void *scratchpad_buffer_; + scratchpad_t *global_scratchpad_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp new file mode 100644 index 0000000000..1d41ac5cea --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp @@ -0,0 +1,544 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "mkldnn_thread.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "utils.hpp" + +#include "cpu_reducer.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void reduce_balancer_t::balance() { + using namespace nstl; + using namespace utils; + + assert(nthr_ > 0 && job_size_ > 0 && njobs_ > 0 && reduction_size_ > 0); + + const int job_complexity = 1; + + const int min_njobs_per_group = max(1, njobs_ / nthr_); + const int max_njobs_per_group = max(1, + static_cast(max_buffer_size_ / (nthr_ * job_size_))); + + /* initial guess */ + int ngroups = min(njobs_ / min_njobs_per_group, nthr_); + int nthr_per_group = syncable_ ? min(nthr_ / ngroups, reduction_size_) : 1; + int njobs_per_group_ub = div_up(njobs_, ngroups); + + /* rough upper-bound estimation, will be fixed during brute force */ + size_t thread_complexity_ub = njobs_ * job_size_ * reduction_size_; + + /* brute force parameters for the best balance... */ + for (int c_njobs_per_group = min_njobs_per_group; + c_njobs_per_group < njobs_; ++c_njobs_per_group) { + /* current assumption */ + int c_ngroups = min(njobs_ / c_njobs_per_group, nthr_); + int c_nthr_per_group = syncable_ + ? min(nthr_ / c_ngroups, reduction_size_) : 1; + int c_njobs_per_group_ub = div_up(njobs_, c_ngroups); + + if (c_nthr_per_group > 1 && c_njobs_per_group_ub > max_njobs_per_group) + continue; + + int c_thread_reduction_ub = div_up(reduction_size_, c_nthr_per_group); + size_t c_group_size_ub = job_size_ * c_njobs_per_group_ub; + size_t c_thread_complexity_ub = c_group_size_ub * ( + job_complexity * c_thread_reduction_ub + + (c_nthr_per_group != 1)); + + if (c_thread_complexity_ub < thread_complexity_ub) { + ngroups = c_ngroups; + nthr_per_group = c_nthr_per_group; + njobs_per_group_ub = c_njobs_per_group_ub; + thread_complexity_ub = c_thread_complexity_ub; + } + } + + assert(njobs_per_group_ub <= max_njobs_per_group || nthr_per_group == 1); + assert(ngroups * nthr_per_group <= nthr_); + assert((size_t)njobs_per_group_ub * job_size_ * nthr_ <= max_buffer_size_ + || nthr_per_group == 1); /* no reduction buffer overflow */ + assert(IMPLICATION(!syncable_, nthr_per_group == 1)); + + ngroups_ = ngroups; + nthr_per_group_ = nthr_per_group; + njobs_per_group_ub_ = njobs_per_group_ub; +} + +/* reducer jit-ted driver */ + +using namespace Xbyak; + +template +struct reducer_2d_driver_t: public c_compatible { + typedef typename prec_traits::type data_t; + + reducer_2d_driver_t(int n_src, size_t src_ld, + size_t src_step, size_t dst_step, bool nullify_dst) + : n_src_(n_src), src_ld_(src_ld), src_step_(src_step) + , dst_step_(dst_step), nullify_dst_(nullify_dst), ker_(nullptr) {} + virtual ~reducer_2d_driver_t() {} + void operator()(data_t *dst, const data_t *srcs, size_t ny, size_t nx) + { assert(ker_); ker_(dst, srcs, ny, nx); } + +protected: + int n_src_; + size_t src_ld_, src_step_, dst_step_; + bool nullify_dst_; + void (*ker_)(data_t *dst, const data_t *srcs, size_t ny, size_t nx); +}; + +template +struct reducer_2d_driver_f_s_32_t: public reducer_2d_driver_t, + public jit_generator +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + const AddressFrame &vmmword = (isa == avx2) ? yword : zword; + void uni_vadd(const Xmm& x1, const Xmm& x2, const Operand& op) + { if (data_type == data_type::f32) vaddps(x1, x2, op); + else vpaddd(x1, x2, op); } + void uni_add(const Xmm& x1, const Operand& op) + { if (data_type == data_type::f32) addss(x1, op); else paddd(x1, op); } + + const int vlen = cpu_isa_traits::vlen; + const int typesize + = sizeof(typename mkldnn::impl::prec_traits::type); + Xbyak::Reg64 reg_dst = abi_param1; + Xbyak::Reg64 reg_src = abi_param2; + Xbyak::Reg64 reg_ny = abi_param3; + Xbyak::Reg64 reg_nx = abi_param4; + + Xbyak::Reg64 reg_x = rax; + Xbyak::Reg64 reg_src_id = r10; + + reducer_2d_driver_f_s_32_t(int n_src, size_t src_ld, size_t src_step, + size_t dst_step, bool nullify_dst) + : reducer_2d_driver_t(n_src, src_ld, src_step, + dst_step, nullify_dst) + { generate(); } + + void nullify_dst(int nloads, int load_len) { + UNUSED(load_len); + for (int i = 0; i < nloads; ++i) + uni_vpxor(Vmm(i), Vmm(i), Vmm(i)); + /* prefetches[dst] ? */ + } + + void load_dst(int nloads, int load_len) { + for (int i = 0; i < nloads; ++i) { + if (load_len == typesize) + movd(Xmm(i), ptr[reg_dst + i * load_len]); + else if (load_len == vlen) + vmovups(Vmm(i), ptr[reg_dst + i * load_len]); + else + assert(!"unsupported"); + } + } + + void store_dst(int nloads, int load_len) { + for (int i = 0; i < nloads; ++i) { + if (load_len == typesize) + movd(ptr[reg_dst + i * load_len], Xmm(i)); + else if (load_len == vlen) + vmovups(ptr[reg_dst + i * load_len], Vmm(i)); + else + assert(!"unsupported"); + } + } + + void accumulate(int nloads, int load_len, size_t base_off) { + for (int i = 0; i < nloads; ++i) { + size_t off = base_off + i * load_len; + + if (load_len == typesize) + uni_add(Xmm(i), ptr[reg_src + off]); + else if (load_len == vlen) + uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]); + else + assert(!"unsupported"); + } + } + + void loop_x() { + const int nloads[] = {cpu_isa_traits::n_vregs, 1, 1}; + const int nbranches = sizeof(nloads) / sizeof(nloads[0]); + + const int load_len[nbranches] = {vlen, vlen, typesize}; + Label loop_x_label[nbranches + 1]; + + mov(reg_x, reg_nx); + + for (int id = 0; id < nbranches; ++id) { + L(loop_x_label[id]); + + cmp(reg_x, nloads[id] * load_len[id]); + jl(loop_x_label[id + 1], T_NEAR); + + if (this->nullify_dst_) + nullify_dst(nloads[id], load_len[id]); + else + load_dst(nloads[id], load_len[id]); + + if (nloads[id] > 1) { + Label loop_srcs; + mov(reg_src_id, this->n_src_); + L(loop_srcs); + + accumulate(nloads[id], load_len[id], 0); + add(reg_src, this->src_ld_ * typesize); + + dec(reg_src_id); + jnz(loop_srcs, T_NEAR); + + sub(reg_src, this->n_src_ * this->src_ld_ * typesize); + } else { + for (int src_id = 0; src_id < this->n_src_; ++src_id) { + const size_t base_off = src_id * this->src_ld_ * typesize; + accumulate(nloads[id], load_len[id], base_off); + } + } + + store_dst(nloads[id], load_len[id]); + + add(reg_src, nloads[id] * load_len[id]); + add(reg_dst, nloads[id] * load_len[id]); + + sub(reg_x, nloads[id] * load_len[id]); + + jmp(loop_x_label[id], T_NEAR); + } + + L(loop_x_label[nbranches]); + + /* restore address registers */ + sub(reg_src, reg_nx); + sub(reg_dst, reg_nx); + } + + void generate() { + assert(isa == avx2 || isa == avx512_common || isa == avx512_mic); + + preamble(); + + shl(reg_nx, 2); + + Label ny_loop; + L(ny_loop); + + loop_x(); + + add(reg_dst, this->dst_step_ * typesize); + add(reg_src, this->src_step_ * typesize); + + dec(reg_ny); + jnz(ny_loop, T_NEAR); + + postamble(); + this->ker_ = reinterpret_castker_)>( + const_cast(this->getCode())); + } +}; + +template +inline reducer_2d_driver_t *create_reduce_2d_drv(int n_src, + size_t src_ld, size_t src_step, size_t dst_step, bool nullify_dst) { + if (mayiuse(avx512_common)) + return new reducer_2d_driver_f_s_32_t(n_src, + src_ld, src_step, dst_step, nullify_dst); + else if (mayiuse(avx2)) + return new reducer_2d_driver_f_s_32_t(n_src, src_ld, + src_step, dst_step, nullify_dst); + assert(!"unimplemented"); + return nullptr; +} + +/* cpu_reducer_t */ + +template +void cpu_reducer_t::conf_t::init_scratchpad( + memory_tracking::registrar_t &scratchpad) const { + if (balancer_.nthr_per_group_ == 1) return; + + const size_t space_size = balancer_.ngroups_ + * (balancer_.nthr_per_group_ - 1) + * cpu_reducer_t::space_per_thread(balancer_); + scratchpad.book(key_reducer_space, sizeof(data_t) * space_size, PAGE_4K); + scratchpad.book(key_reducer_space_bctx, + sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); +} + +template +cpu_reducer_t::cpu_reducer_t(const conf_t &conf) + : conf_(conf), drv_(nullptr) +{ + if (balancer().nthr_per_group_ == 1) return; + + drv_ = create_reduce_2d_drv(balancer().nthr_per_group_ - 1, + space_per_thread(balancer()), 0, 0, false); +} + +template +cpu_reducer_t::~cpu_reducer_t() { delete drv_; } + +template +typename cpu_reducer_t::data_t * +cpu_reducer_t::get_local_ptr(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const int id_in_grp = balancer().id_in_group(ithr); + + /* threads 0 from each group writes directly to the destination */ + if (id_in_grp == 0) + return dst + balancer().ithr_job_off(ithr) * balancer().job_size_; + + const int grp_id = balancer().group_id(ithr); + const int offset_factor = grp_id * (balancer().nthr_per_group_ - 1) + + (id_in_grp - 1); + + auto space = scratchpad.template get(key_reducer_space); + return space + offset_factor * space_per_thread(balancer()); +} + +template +void cpu_reducer_t::reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + +#ifdef SIMPLE_IMPL + if (balancer().id_in_group(ithr) != 0) + return; /* only threads 0 do the reduction */ + + const int njobs_in_grp = balancer().ithr_njobs(ithr); + data_t *d = get_local_ptr(ithr, dst, scratchpad); + for (int id_in_grp = 1; id_in_grp < balancer_.nthr_per_group_; ++id_in_grp) + { + const data_t *space = get_local_ptr(ithr + id_in_grp, dst, scratchpad); + for (size_t i = 0; i < (size_t)njobs_in_grp * balancer().job_size_; ++i) + d[i] += space[i]; + } +#else + using namespace utils; + + const int id_in_grp = balancer().id_in_group(ithr); + const int njobs_in_grp = balancer().ithr_njobs(ithr); + const size_t cl = 64 / sizeof(data_t); + + const size_t reduction_size = njobs_in_grp * balancer().job_size_; + size_t start{0}, end{0}; + balance211(div_up(reduction_size, cl), balancer().nthr_per_group_, + id_in_grp, start, end); + + if (start == end) return; + + data_t *d = get_local_ptr(ithr - id_in_grp, dst, scratchpad) + start * cl; + const data_t *space = get_local_ptr(ithr - id_in_grp + 1, dst, scratchpad) + + start * cl; + const size_t len = nstl::min(end * cl, reduction_size) - start * cl; + + (*drv_)(d, space, 1, len); +#endif +} + +template struct cpu_reducer_t; +template struct cpu_reducer_t; + +/* cpu_reducer_2d_t */ + +template +void cpu_reducer_2d_t::conf_t::init_scratchpad( + memory_tracking::registrar_t &scratchpad) const { + if (balancer_.nthr_per_group_ == 1) return; + + const size_t space_size = balancer_.ngroups_ * balancer_.nthr_per_group_ + * cpu_reducer_2d_t::space_per_thread(balancer_); + scratchpad.book(key_reducer_space, sizeof(data_t) * space_size); + scratchpad.book(key_reducer_space_bctx, + sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); +} + +template +cpu_reducer_2d_t::cpu_reducer_2d_t(const conf_t &conf) + : conf_(conf), drv_(nullptr) +{ + if (balancer().nthr_per_group_ == 1) return; + + drv_ = create_reduce_2d_drv(balancer().nthr_per_group_, + space_per_thread(balancer()), conf_.job_size_x_, conf_.dst_x_, + true); +} + +template +cpu_reducer_2d_t::~cpu_reducer_2d_t() { delete drv_; } + +template +typename cpu_reducer_2d_t::data_t *cpu_reducer_2d_t:: +get_local_ptr(int ithr, const memory_tracking::grantor_t &scratchpad) const { + const int id_in_grp = balancer().id_in_group(ithr); + const int grp_id = balancer().group_id(ithr); + const int offset_factor = grp_id * balancer().nthr_per_group_ + id_in_grp; + auto space = scratchpad.template get(key_reducer_space); + return space + offset_factor * space_per_thread(balancer()); +} + +template +int cpu_reducer_2d_t::choose_x_blocking(int nx, int ny, + int nthr_per_grp) const { + // find x_blocking for better balance reducing work between threads + assert(conf_.x_block_ > 0 && nx > conf_.x_block_ + && nx % conf_.x_block_ == 0); + int x_blocking = nx / conf_.x_block_; + int min_x_blocking = + utils::div_up(x_blocking, nstl::max(1, nthr_per_grp / ny)); + while (true) { + if (x_blocking % 2 == 0 && x_blocking >= min_x_blocking * 2) + x_blocking /= 2; + else if (x_blocking % 3 == 0 && x_blocking >= min_x_blocking * 3) + x_blocking /= 3; + else + break; + } + if (x_blocking >= min_x_blocking * 4) x_blocking = 1; + x_blocking *= conf_.x_block_; + return x_blocking; +} + +template +void cpu_reducer_2d_t::reduce_block(const data_t* space_base, + data_t *dst, int job, int start_y, int start_x, + int ny_start, int nx_start, int ny_step, int nx_step) const { + data_t *d = dst + (start_y + ny_start) * conf_.dst_x_ + + start_x + nx_start; + const data_t *space = space_base + job * balancer().job_size_ + + ny_start * conf_.job_size_x_ + nx_start; +#ifdef SIMPLE_IMPL + for (int idg = 0; idg < balancer().nthr_per_group_; ++idg) { + const data_t *w = &space[idg * space_per_thread(balancer())]; + for (int y = 0; y < ny_step; ++y) + for (int x = 0; x < nx_step; ++x) { + d[y * conf_.dst_x_ + x] + = (idg == 0 ? 0 : d[y * conf_.dst_x_ + x]) + + w[y * conf_.job_size_x_ + x]; + } + } +#else + (*drv_)(d, space, ny_step, nx_step); +#endif +} + +template +void cpu_reducer_2d_t::reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + const int id_in_grp = balancer().id_in_group(ithr); + const int njobs_in_grp = balancer().ithr_njobs(ithr); + const int njobs_x = utils::div_up(conf_.dst_x_, conf_.job_size_x_); + const int global_job_start = balancer().ithr_job_off(ithr); + + const data_t *space_base = get_local_ptr(ithr - id_in_grp, scratchpad); + + const int pr_grps = nstl::min(njobs_in_grp, balancer().nthr_per_group_); + const int pr_nthr_per_grp = balancer().nthr_per_group_ / pr_grps; + + if (id_in_grp >= pr_grps * pr_nthr_per_grp) + return; /* idle */ + + const int pr_my_grp = id_in_grp / pr_nthr_per_grp; + const int pr_my_id = id_in_grp % pr_nthr_per_grp; + + int pr_job_start{0}, pr_job_end{0}; + balance211(njobs_in_grp, pr_grps, pr_my_grp, pr_job_start, pr_job_end); + + for (int j = pr_job_start; j < pr_job_end; ++j) { + const int global_job = global_job_start + j; + const int j_y = global_job / njobs_x; + const int j_x = global_job % njobs_x; + const int start_y = j_y * conf_.job_size_y_; + const int start_x = j_x * conf_.job_size_x_; + const int ny = nstl::min(conf_.dst_y_ - start_y, conf_.job_size_y_); + const int nx = nstl::min(conf_.dst_x_ - start_x, conf_.job_size_x_); + int x_blocking = choose_x_blocking(nx, ny, pr_nthr_per_grp); + + int nxy_start{0}, nxy_end{0}; + balance211(ny * nx / x_blocking, pr_nthr_per_grp, pr_my_id, + nxy_start, nxy_end); + if (nxy_start == nxy_end) continue; + nxy_start *= x_blocking; + nxy_end *= x_blocking; + + int nxy = nxy_start; + if (nxy % nx != 0) { + int nx_step = nstl::min(nx - nxy % nx, nxy_end - nxy); + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, 1, nx_step); + nxy += nx_step; + } + if ((nxy_end - nxy) > nx) { + int ny_step = (nxy_end - nxy) / nx; + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, ny_step, nx); + nxy += nx * ny_step; + } + if ((nxy_end - nxy) > 0) { + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, 1, nxy_end - nxy); + } + } +} + +template struct cpu_reducer_2d_t; +template struct cpu_reducer_2d_t; + +/* accumulator section */ + +template +cpu_accumulator_1d_t::cpu_accumulator_1d_t(): drv_(nullptr) { + drv_ = create_reduce_2d_drv(1, 0, 0, 0, false); +} + +template +cpu_accumulator_1d_t::~cpu_accumulator_1d_t() { + delete drv_; +} + +template +void cpu_accumulator_1d_t::accumulate(data_t *dst, + const data_t *src, size_t size) { + (*drv_)(dst, src, 1, size); +} + +template struct cpu_accumulator_1d_t; +template struct cpu_accumulator_1d_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp new file mode 100644 index 0000000000..27f5939cd2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp @@ -0,0 +1,334 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REDUCER_HPP +#define CPU_REDUCER_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "cpu_barrier.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +/** class to perform balancing over 3D array + * + * Conceptually the reduction happens according to the picture below: + * + * <--job_size-> + * +-----------+ +-----------+ +-----------+ ^ + * | | | | | | | + * | | | | | | | + * | 1 | | 2 | . . . | njobs | | reduction_size + * | | | | | | | + * | | | | | | | + * +-----------+ +-----------+ +-----------+ v + * + * | | | | | | | | | + * v v v v v v v v v + * ===================================================== vertical reduction + * + * +-----------+ +-----------+ . . . +-----------+ result + * + * In a simple case the result must be contiguous in memory. + * @class cpu_reducer_t is an implementation. + * + * Threads are divided into groups. The groups are independent of each other. + * Each group may work on several jobs (the distribution is not uniform, since + * njobs might be not a multiple of groups). Threads within a group work on + * different parts of the reduction dimension. Thread 0 in each group is called + * master (@sa reduce_balancer_t::master()). + * + * If threading driver does not allow sync between sub-group of threads (e.g. + * Intel(R) TBB) the # of thread per group is enforced to be 1. + */ +struct reduce_balancer_t { + reduce_balancer_t() { init(1, 1, 1, 1, 0); } /* trivial balance */ + reduce_balancer_t(int nthr, int job_size, int njobs, int reduction_size, + size_t max_buffer_size) + { init(nthr, job_size, njobs, reduction_size, max_buffer_size); } + + reduce_balancer_t &init(int nthr, int job_size, int njobs, + int reduction_size, size_t max_buffer_size) + { + syncable_ = mkldnn_thr_syncable(); + nthr_ = nthr; + job_size_ = job_size; + njobs_ = njobs; + reduction_size_ = reduction_size; + max_buffer_size_ = max_buffer_size; + balance(); + return *this; + } + + bool syncable_; + int nthr_; + int job_size_, njobs_, reduction_size_; + + int ngroups_; /** number of independent work (thread) groups */ + int nthr_per_group_; /** number of threads within a single work group */ + int njobs_per_group_ub_; /** the max # of jobs within a work group */ + + bool master(int ithr) const { return id_in_group(ithr) == 0; } + bool idle(int ithr) const { return ithr >= nthr_per_group_ * ngroups_; } + + int group_id(int ithr) const { return ithr / nthr_per_group_; } + int id_in_group(int ithr) const { return ithr % nthr_per_group_; } + + int grp_njobs(int grp) const { + if (grp >= ngroups_) return 0; + return njobs_ / ngroups_ + (grp < njobs_ % ngroups_); + } + int grp_job_off(int grp) const { + if (grp >= ngroups_) return njobs_; + return njobs_ / ngroups_ * grp + nstl::min(grp, njobs_ % ngroups_); + } + + int ithr_njobs(int ithr) const { return grp_njobs(group_id(ithr)); } + int ithr_job_off(int ithr) const { return grp_job_off(group_id(ithr)); } + +private: + size_t max_buffer_size_; + void balance(); +}; + +/** forward declaration of reduce driver */ +template struct reducer_2d_driver_t; + +/** class to perform a reduction over 3D array + * + * Balancing is based on @class reduce_balancer_t. + * Restrictions: the result of the reduction must be contiguous in memory. * + * The reduction happens according to the picture below (once more): + * + * <--job_size-> + * +-----------+ +-----------+ +-----------+ ^ + * | | | | | | | + * | | | | | | | + * | 1 | | 2 | . . . | njobs | | reduction_size + * | | | | | | | + * | | | | | | | + * +-----------+ +-----------+ +-----------+ v + * + * | | | | | | | | | + * v v v v v v v v v + * ===================================================== vertical reduction + * + * +-----------+ +-----------+ . . . +-----------+ (contiguous) result + * + * An example how work might be shared is shown below. + * + * In this example group 0 owns 2 (independent) jobs -- 2 big squares. + * The number of threads per group is also 2 (thread 0 of group 0 and thread 1 + * of group 0). Master threads (i.e. threads with id 0 in corresponding group) + * from each group put the partial result directly into destination memory, + * while all the other threads with-in the group use workspace (on the picture + * the only thread 1). Once intermediate results obtained each group reduces + * corresponding part (own jobs) to the destination memory. + * + * <------- group 0 -------> + * + * +-----------+ +-----------+ ^ + * | | | | | thread 0 of reduces to the dest-memory + * | | | | | group 0 +-----------+ +-----------+ + * |- - - - - -| |- - - - - -| X + * | | | | | thread 1 of reduces to workspace[tid=1]: + * | | | | | group 0 +-----------+ +-----------+ + * +-----------+ +-----------+ v + * | | | | | | + * v v v v v v + * ((barrier)) ============================= + * + * dest-memory: +-----------+ +-----------+ + */ +template +struct cpu_reducer_t { + typedef typename prec_traits::type data_t; + + struct conf_t { + conf_t() = default; + conf_t &init(const reduce_balancer_t &balancer) + { balancer_ = balancer; return *this; } + + void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; + + reduce_balancer_t balancer_; + }; + + cpu_reducer_t(const conf_t &conf); + ~cpu_reducer_t(); + + /** initializes reducer. + * Must be called from a single thread prior to actual usage */ + void init(const memory_tracking::grantor_t &scratchpad) const { + if (balancer().nthr_per_group_ == 1) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + for (int i = 0; i < balancer().ngroups_; ++i) + simple_barrier::ctx_init(&bctx[i]); + } + + /** for given thread returns the pointer where to put partial results. + * Reduction destination @p dst must be provided as well (master threads + * from each group will use it for partial result to reduce memory + * pressure). + * + * @note: job offset is already applied by get_local_ptr(), which means all + * threads should start writing from the very beginning of returned + * address. + */ + data_t *get_local_ptr(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + + /** performs the reduction with built-in synchronization. */ + void reduce(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + simple_barrier::barrier(&bctx[balancer().group_id(ithr)], + balancer().nthr_per_group_); + + reduce_nolock(ithr, dst, scratchpad); + } + + const reduce_balancer_t &balancer() const { return conf_.balancer_; } + +private: + static size_t space_per_thread(const reduce_balancer_t &balancer) + { return balancer.njobs_per_group_ub_ * balancer.job_size_; } + + /* The scratchpad is organized as follows: + * + * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; + * simple_barrier::ctx_t barriers[groups_]; */ + + const conf_t conf_; + reducer_2d_driver_t *drv_; + + void reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; +}; + +template +struct cpu_reducer_2d_t { + typedef typename prec_traits::type data_t; + + struct conf_t { + conf_t() = default; + conf_t &init(const reduce_balancer_t &balancer, int job_size_x, + int job_size_y, int x_block, int dst_x, int dst_y) { + balancer_ = balancer; + job_size_x_ = job_size_x; + job_size_y_ = job_size_y; + x_block_ = x_block; + dst_x_ = dst_x; + dst_y_ = dst_y; + return *this; + } + + void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; + + reduce_balancer_t balancer_; + int job_size_x_, job_size_y_, x_block_, dst_x_, dst_y_; + }; + + cpu_reducer_2d_t(const conf_t &conf); + ~cpu_reducer_2d_t(); + + /** initializes reducer. + * Must be called from a single thread prior to actual usage */ + void init(const memory_tracking::grantor_t &scratchpad) const { + if (balancer().nthr_per_group_ == 1) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + for (int i = 0; i < balancer().ngroups_; ++i) + simple_barrier::ctx_init(&bctx[i]); + } + + /** for given thread returns the pointer where to put partial results */ + data_t *get_local_ptr(int ithr, + const memory_tracking::grantor_t &scratchpad) const; + + /** performs the reduction with built-in synchronization. */ + void reduce(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + auto bctx = scratchpad.template get( + memory_tracking::names::key_reducer_space_bctx); + simple_barrier::barrier(&bctx[balancer().group_id(ithr)], + balancer().nthr_per_group_); + + reduce_nolock(ithr, dst, scratchpad); + } + + const reduce_balancer_t &balancer() const { return conf_.balancer_; } + +private: + static size_t space_per_thread(const reduce_balancer_t &balancer) + { return balancer.njobs_per_group_ub_ * balancer.job_size_; } + + /* The scratchpad is organized as follows: + * + * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; + * simple_barrier::ctx_t barriers[groups_]; */ + + const conf_t conf_; + reducer_2d_driver_t *drv_; + + int choose_x_blocking(int nx, int ny, int nthr_per_grp) const; + void reduce_block(const data_t* space_base, data_t *dst, + int job, int start_y, int start_x, + int ny_start, int nx_start, int ny_step, int nx_step) const; + void reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; +}; + +/** simple 1d accumulator: y[:] += x[:] */ +template +struct cpu_accumulator_1d_t { + typedef typename prec_traits::type data_t; + + cpu_accumulator_1d_t(); + ~cpu_accumulator_1d_t(); + void accumulate(data_t *dst, const data_t *src, size_t size); + + reducer_2d_driver_t *drv_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp new file mode 100644 index 0000000000..82be70353d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp @@ -0,0 +1,262 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "cpu_engine.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" +#include "cpu_memory.hpp" +#include "type_helpers.hpp" + +#include "cpu/jit_uni_reorder.hpp" +#include "cpu/simple_reorder.hpp" +#include "cpu/wino_reorder.hpp" +#include "cpu/rnn/rnn_reorders.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f; + +namespace { +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::format_tag; + +#define REG_SR(idt, ifmt, odt, ofmt, ...) \ + simple_reorder_t::pd_t::create + +#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse) + +#define REG_SR_DIRECT_COPY(idt, odt) \ + REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \ + REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0) + +static const rpd_create_f cpu_reorder_impl_list[] = { + /* winograd */ + wino_reorder_t::pd_t::create, + //wino_reorder_t::pd_t::create, + + /* rnn reorders */ + rnn_data_reorder_t::pd_t::create, + rnn_weights_reorder_t::pd_t::create, + rnn_weights_reorder_t::pd_t::create, + + /* conv reorders w/ compensation */ + REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), + + /* regular reorders */ + +#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__)) + /* Direct copy for icc which is faster than jitted code; + * Direct copy for gcc which might or might not be faster than jitted + * code, but still worth it because doesn't require jitting, i.e. much + * faster creation time. This is tentative solution and should be removed + * later (when we will cache jitted code?...). */ + REG_SR_DIRECT_COPY(f32, f32), +#endif + +#ifdef __INTEL_COMPILER + /* direct copy for icc, which is faster than jitted code */ + /* + REG_SR_DIRECT_COPY(f32, s32), + REG_SR_DIRECT_COPY(f32, s8), + REG_SR_DIRECT_COPY(f32, u8), + REG_SR_DIRECT_COPY(s32, f32), + REG_SR_DIRECT_COPY(s32, s32), + REG_SR_DIRECT_COPY(s32, s8), + REG_SR_DIRECT_COPY(s32, u8), + REG_SR_DIRECT_COPY(s8, f32), + REG_SR_DIRECT_COPY(s8, s32), + REG_SR_DIRECT_COPY(s8, s8), + REG_SR_DIRECT_COPY(s8, u8), + REG_SR_DIRECT_COPY(u8, f32), + REG_SR_DIRECT_COPY(u8, s32), + REG_SR_DIRECT_COPY(u8, s8), + REG_SR_DIRECT_COPY(u8, u8), + */ +#endif + + /* jit */ + jit_uni_reorder_create, + + /* fp32: flat <-> blocked with tail */ + /* + REG_SR_BIDIR(f32, any, f32, nCw4c), + REG_SR_BIDIR(f32, any, f32, nCw8c), + REG_SR_BIDIR(f32, any, f32, OIw4i4o), + REG_SR_BIDIR(f32, any, f32, OIw8i8o), + REG_SR_BIDIR(f32, any, f32, OIw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIw4i4o), + REG_SR_BIDIR(f32, any, f32, gOIw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIw8o8i), + + REG_SR_BIDIR(f32, any, f32, nCw16c), + REG_SR_BIDIR(f32, any, f32, OIw16o16i), + REG_SR_BIDIR(f32, any, f32, OIw16i16o), + REG_SR_BIDIR(f32, any, f32, IOw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIw16i16o), + REG_SR_BIDIR(f32, any, f32, gIOw16o16i), + + REG_SR_BIDIR(f32, any, f32, nChw4c), + REG_SR_BIDIR(f32, any, f32, nChw8c), + REG_SR_BIDIR(f32, any, f32, OIhw4i4o), + REG_SR_BIDIR(f32, any, f32, Ohwi8o), + + REG_SR_BIDIR(f32, any, f32, OIhw8i8o), + REG_SR_BIDIR(f32, any, f32, OIhw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIhw4i4o), + REG_SR_BIDIR(f32, any, f32, gOIhw4o4i), + REG_SR_BIDIR(f32, any, f32, gOhwi8o), + REG_SR_BIDIR(f32, any, f32, gOIhw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIhw8o8i), + + REG_SR_BIDIR(f32, any, f32, nChw16c), + REG_SR_BIDIR(f32, any, f32, Oihw4o), + REG_SR_BIDIR(f32, any, f32, Oihw16o), + REG_SR_BIDIR(f32, any, f32, Ohwi4o), + REG_SR_BIDIR(f32, any, f32, Ohwi16o), + REG_SR_BIDIR(f32, any, f32, OIhw16o16i), + REG_SR_BIDIR(f32, any, f32, OIhw16i16o), + REG_SR_BIDIR(f32, any, f32, IOhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOihw4o), + REG_SR_BIDIR(f32, any, f32, gOihw16o), + REG_SR_BIDIR(f32, any, f32, gOhwi4o), + REG_SR_BIDIR(f32, any, f32, gOhwi16o), + REG_SR_BIDIR(f32, any, f32, gOIhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIhw16i16o), + REG_SR_BIDIR(f32, any, f32, gIOhw16o16i), + + REG_SR_BIDIR(f32, any, f32, nCdhw4c), + REG_SR_BIDIR(f32, any, f32, nCdhw8c), + REG_SR_BIDIR(f32, any, f32, OIdhw4i4o), + REG_SR_BIDIR(f32, any, f32, Odhwi8o), + REG_SR_BIDIR(f32, any, f32, OIdhw8i8o), + REG_SR_BIDIR(f32, any, f32, OIdhw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o), + REG_SR_BIDIR(f32, any, f32, gOdhwi8o), + REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i), + + REG_SR_BIDIR(f32, any, f32, nCdhw16c), + REG_SR_BIDIR(f32, any, f32, Oidhw4o), + REG_SR_BIDIR(f32, any, f32, Oidhw16o), + REG_SR_BIDIR(f32, any, f32, Odhwi16o), + REG_SR_BIDIR(f32, any, f32, OIdhw16o16i), + REG_SR_BIDIR(f32, any, f32, OIdhw16i16o), + REG_SR_BIDIR(f32, any, f32, gOidhw4o), + REG_SR_BIDIR(f32, any, f32, gOidhw16o), + REG_SR_BIDIR(f32, any, f32, gOdhwi16o), + REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o), + */ + + /* fp32: blocked <-> blocked with tail */ + REG_SR_BIDIR(f32, nCw8c, f32, nCw16c), + REG_SR_BIDIR(f32, nChw8c, f32, nChw16c), + REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c), + + /* int: flat <-> blocked with tail */ + /* + REG_SR_BIDIR(f32, any, s32, nChw16c), + REG_SR_BIDIR(f32, any, s8, nChw16c), + REG_SR_BIDIR(f32, any, u8, nChw16c), + REG_SR_BIDIR(s32, any, f32, nChw16c), + REG_SR_BIDIR(s32, any, s32, nChw16c), + REG_SR_BIDIR(s32, any, s8, nChw16c), + REG_SR_BIDIR(s32, any, u8, nChw16c), + REG_SR_BIDIR(s8, any, f32, nChw16c), + REG_SR_BIDIR(s8, any, s32, nChw16c), + REG_SR_BIDIR(s8, any, s8, nChw16c), + REG_SR_BIDIR(s8, any, u8, nChw16c), + REG_SR_BIDIR(u8, any, f32, nChw16c), + REG_SR_BIDIR(u8, any, s32, nChw16c), + REG_SR_BIDIR(u8, any, s8, nChw16c), + REG_SR_BIDIR(u8, any, u8, nChw16c), + + REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i), + REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i), + REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i), + REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i), + REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i), + REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i), + REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i), + REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i), + */ + + /* reference: the last line of defence */ + /* + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference), + REG_SR(f32, any, s32, any, fmt_order::any, spec::reference), + REG_SR(f32, any, s8, any, fmt_order::any, spec::reference), + REG_SR(f32, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(s32, any, f32, any, fmt_order::any, spec::reference), + REG_SR(s32, any, s32, any, fmt_order::any, spec::reference), + REG_SR(s32, any, s8, any, fmt_order::any, spec::reference), + REG_SR(s32, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(s8, any, f32, any, fmt_order::any, spec::reference), + REG_SR(s8, any, s32, any, fmt_order::any, spec::reference), + REG_SR(s8, any, s8, any, fmt_order::any, spec::reference), + REG_SR(s8, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(u8, any, f32, any, fmt_order::any, spec::reference), + REG_SR(u8, any, s32, any, fmt_order::any, spec::reference), + REG_SR(u8, any, u8, any, fmt_order::any, spec::reference), + REG_SR(u8, any, s8, any, fmt_order::any, spec::reference), + */ + + /* eol */ + nullptr, +}; +} + +const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const { + return cpu_reorder_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp new file mode 100644 index 0000000000..1622eb6849 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REORDER_PD_HPP +#define CPU_REORDER_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "reorder_pd.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_reorder_pd_t: public reorder_pd_t { + using reorder_pd_t::reorder_pd_t; + + status_t init() { + const auto &post_ops = attr()->post_ops_; + bool args_ok = IMPLICATION(post_ops.len_ != 0, post_ops.len_ == 1 + && post_ops.entry_[0].kind == primitive_kind::sum); + scratchpad_engine_ = src_engine_; + return args_ok ? status::success : status::unimplemented; + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp new file mode 100644 index 0000000000..f16587b99f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SHUFFLE_PD_HPP +#define CPU_SHUFFLE_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "shuffle_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_shuffle_pd_t: public shuffle_pd_t { + using shuffle_pd_t::shuffle_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp new file mode 100644 index 0000000000..3a39eab974 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SOFTMAX_PD_HPP +#define CPU_SOFTMAX_PD_HPP + +#include + +#include "c_types_map.hpp" +#include "softmax_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_softmax_fwd_pd_t: public softmax_fwd_pd_t { + using softmax_fwd_pd_t::softmax_fwd_pd_t; +}; + +struct cpu_softmax_bwd_pd_t: public softmax_bwd_pd_t { + using softmax_bwd_pd_t::softmax_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp new file mode 100644 index 0000000000..1ab5d9f174 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu_engine.hpp" + +/* +#include "cpu/ref_sum.hpp" +#include "cpu/simple_sum.hpp" +*/ + +namespace mkldnn { +namespace impl { +namespace cpu { + +using spd_create_f = mkldnn::impl::engine_t::sum_primitive_desc_create_f; + +namespace { +#define INSTANCE(...) __VA_ARGS__::pd_t::create +static const spd_create_f cpu_sum_impl_list[] = { + /* + INSTANCE(simple_sum_t), + INSTANCE(ref_sum_t), + */ + nullptr, +}; +#undef INSTANCE +} + +const spd_create_f *cpu_engine_t::get_sum_implementation_list() const { + return cpu_sum_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp new file mode 100644 index 0000000000..0965129f9b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp @@ -0,0 +1,39 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SUM_PD_HPP +#define CPU_SUM_PD_HPP + +#include "c_types_map.hpp" +#include "sum_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_sum_pd_t: public sum_pd_t { + using sum_pd_t::sum_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp new file mode 100644 index 0000000000..a9810dec28 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp @@ -0,0 +1,372 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "gemm_utils_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace gemm_utils { +#define BM_NOCOPY_AVX 64 +#define BN_NOCOPY_AVX 48 +#define BK_NOCOPY_AVX 384 +#define BN_LARGE_NOCOPY_AVX 192 +#define BM_SMALL_NOCOPY_AVX 16 +#define BN_SMALL_NOCOPY_AVX 1 +#define BK_SMALL_NOCOPY_AVX 4 +// Determine number of threads for each dimension of a 3-D partitioning +// algorithm based on input parameters +// m/n/k - First/second/third parameter for GEMM +// nthrs - total available number of threads +// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension +// BM/BN/BK - blocking values +void calc_nthr_nocopy_avx(int m, int n, int k, + int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, + int *BK) +{ + int nthr, nthr_m, nthr_n, nthr_k; + int MB, NB, KB; + + nthr = nthrs; + nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX; + nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX; + nthr_k = 1; + + // Partition along K dimension + // - if threading allows having barriers (e.g. OMP) + // - if there is not enough parallelism along M or N + if (mkldnn_thr_syncable()) { + int nthr_other = nthr_k = 1; + while ((nthr_m * nthr_n * nthr_other < nthr) + && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) { + nthr_other++; + if ((nthr / nthr_other) * nthr_other > 0.9 * nthr) + nthr_k = nthr_other; + } + } + nthr /= nthr_k; + + if (nthr_m == 1) + nthr_n = nthr; + if (nthr_n == 1) + nthr_m = nthr; + + // Simple partition reduction + while (nthr_m * nthr_n > nthr) + if (nthr_m > nthr_n) + nthr_m--; + else + nthr_n--; + while (nthr_m * nthr_n < nthr) + if (nthr_m < nthr_n) + nthr_m++; + else + nthr_n++; + + if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthr); + if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX) + nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX; + nthr_n = nthr / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { + nthr_m--; + nthr_n = nthr / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthr); + if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX) + nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX; + nthr_m = nthr / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { + nthr_n--; + nthr_m = nthr / nthr_n; + } + } + } + + MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1; + MB -= MB % BM_SMALL_NOCOPY_AVX; + NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1; + NB -= NB % BN_SMALL_NOCOPY_AVX; + KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1; + KB -= KB % BK_SMALL_NOCOPY_AVX; + + if (MB * nthr_m > m) + nthr_m = (m + MB - 1) / MB; + if (NB * nthr_n > n) + nthr_n = (n + NB - 1) / NB; + if (KB * nthr_k > k) + nthr_k = (k + KB - 1) / KB; + + *nthrs_m = nthr_m; + *nthrs_n = nthr_n; + *nthrs_k = nthr_k; + + *BM = MB; + *BN = NB; + *BK = KB; +} +#undef BM_NOCOPY_AVX +#undef BN_NOCOPY_AVX +#undef BK_NOCOPY_AVX +#undef BN_LARGE_NOCOPY_AVX +#undef BM_SMALL_NOCOPY_AVX +#undef BN_SMALL_NOCOPY_AVX +#undef BK_SMALL_NOCOPY_AVX + +#define BM_NOCOPY_AVX512_COMMON 32 +#define BN_NOCOPY_AVX512_COMMON 64 +#define BK_NOCOPY_AVX512_COMMON 192 +#define BN_LARGE_NOCOPY_AVX512_COMMON 192 +#define BM_SMALL_NOCOPY_AVX512_COMMON 16 +#define BN_SMALL_NOCOPY_AVX512_COMMON 1 +#define BK_SMALL_NOCOPY_AVX512_COMMON 4 +// Determine number of threads for each dimension of a 3-D partitioning +// algorithm based on input parameters +// m/n/k - First/second/third parameter for GEMM +// nthrs - total available number of threads +// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension +// BM/BN/BK - blocking values +void calc_nthr_nocopy_avx512_common(int m, + int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, + int *BM, int *BN, int *BK) +{ + int nthr, nthr_m, nthr_n, nthr_k = 1; + int MB, NB, KB; + nthr = nthrs; + + int counter = 0; + float ratio_float = 1.; + int ratio = 1; + nthr = nthrs; + int nthr_m_gt_n; + + // Partition along K dimension + // - if threading allows having barriers (e.g. OMP) + // - if there is not enough parallelism along M or N + if (mkldnn_thr_syncable()) { + if (n <= 2 * BN_NOCOPY_AVX512_COMMON && + m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) { + nthr_k = k / BK_NOCOPY_AVX512_COMMON; + if (nthr_k > nthr / 4) + nthr_k = nthr / 4; + if (nthr_k < 1) + nthr_k = 1; + + while ((nthr_k > 1) && (nthr % nthr_k)) { + nthr_k--; + } + nthr /= nthr_k; + } else { + nthr_k = 1; + } + } + nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON; + nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON; + + if (nthr_m < 1) + nthr_m = 1; + if (nthr_n < 1) + nthr_n = 1; + + nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0; + ratio_float = (float)nthr_m / nthr_n; + + if (nthr_m_gt_n) + ratio = (int)ratio_float; + else + ratio = (int)(1. / ratio_float); + + // scale down nthr_m and nthr_n if they are too large + while (nthr_m * nthr_n > 4 * nthr) { + nthr_m /= 2; + nthr_n /= 2; + } + + if (nthr_m < 1) + nthr_m = 1; + if (nthr_n < 1) + nthr_n = 1; + + // Simple partition reduction + counter = 0; + while (nthr_m * nthr_n > nthr) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m--; + else { + nthr_n--; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n--; + else { + nthr_m--; + counter = -1; + } + } + counter++; + } + + // Simple partition increment + counter = 0; + while (nthr_m * nthr_n < 0.95 * nthr) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m++; + else { + nthr_n++; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n++; + else { + nthr_m++; + counter = -1; + } + } + counter++; + } + + // if nothing works out, then this should work + if ((nthr_m * nthr_n > nthr)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthr); + if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) + / BM_SMALL_NOCOPY_AVX512_COMMON) + nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) + / BM_SMALL_NOCOPY_AVX512_COMMON; + nthr_n = nthr / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { + nthr_m--; + nthr_n = nthr / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthr); + if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) + / BN_SMALL_NOCOPY_AVX512_COMMON) + nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) + / BN_SMALL_NOCOPY_AVX512_COMMON; + nthr_m = nthr / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { + nthr_n--; + nthr_m = nthr / nthr_n; + } + } + } + + MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1; + MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON; + NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1; + NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON; + KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1; + KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON; + + if (MB * nthr_m > m) + nthr_m = (m + MB - 1) / MB; + if (NB * nthr_n > n) + nthr_n = (n + NB - 1) / NB; + if (KB * nthr_k > k) + nthr_k = (k + KB - 1) / KB; + + *nthrs_m = nthr_m; + *nthrs_n = nthr_n; + *nthrs_k = nthr_k; + + *BM = MB; + *BN = NB; + *BK = KB; +} +#undef BM_NOCOPY_AVX512_COMMON +#undef BN_NOCOPY_AVX512_COMMON +#undef BK_NOCOPY_AVX512_COMMON +#undef BN_LARGE_NOCOPY_AVX512_COMMON +#undef BM_SMALL_NOCOPY_AVX512_COMMON +#undef BN_SMALL_NOCOPY_AVX512_COMMON +#undef BK_SMALL_NOCOPY_AVX512_COMMON + +// Partition n values as equally as possible among nthr threads +// and set the offset (t_offset) and number of values (t_block) for ithr +// Assumption: 0 <= ithr < nthr +void partition_unit_diff( + int ithr, int nthr, int n, int *t_offset, int *t_block) +{ + int band = n / nthr; + if (band == 0) + band = 1; + int tail = n - band * nthr; + if (tail < 0) + tail = 0; + + if (ithr < tail) { + band++; + *t_offset = band * ithr; + *t_block = band; + } else { + *t_offset = band * ithr + tail; + *t_block = band; + } + + if (*t_offset >= n) { + *t_offset = 0; + *t_block = 0; + } + + if (*t_offset + *t_block > n) { + *t_block = n - *t_offset; + } +} + +// Sum the m*n values from p_src into p_dst, assuming the two-dimensional +// arrays have leading dimensions ld_src and ld_dst, respectively +template +void sum_two_matrices(int m, int n, + data_t * __restrict p_src, dim_t ld_src, + data_t * __restrict p_dst, dim_t ld_dst) +{ + int i, j; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + p_dst[i + j * ld_dst] += p_src[i + j * ld_src]; + } + } +} + +template +void sum_two_matrices(int m, int n, + float * __restrict p_src, dim_t ld_src, + float * __restrict p_dst, dim_t ld_dst); + +template +void sum_two_matrices(int m, int n, + double * __restrict p_src, dim_t ld_src, + double * __restrict p_dst, dim_t ld_dst); +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp new file mode 100644 index 0000000000..3352298b4a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_UTILS_HPP +#define GEMM_UTILS_HPP + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace gemm_utils { +// Alias for any dimension related variable. +typedef ptrdiff_t dim_t; + +template +struct gemm_traits {}; + +template +struct gemm_traits { + static constexpr int m = 8; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 192; + static constexpr int BK = isTransB ? 96 : 512; +}; + +template +struct gemm_traits { + static constexpr int m = 16; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 48; + static constexpr int BK = isTransB ? 96 : 256; +}; + +template +using unroll_factor = gemm_traits; + +template +void sum_two_matrices(int m, int n, + data_t * __restrict p_src, dim_t ld_src, + data_t * __restrict p_dst, dim_t ld_dst); + +void calc_nthr_nocopy_avx512_common(int m, + int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, + int *BM, int *BN, int *BK); + +void calc_nthr_nocopy_avx(int m, int n, int k, + int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, + int *BK); + +void partition_unit_diff( + int ithr, int nthr, int n, int *t_offset, int *t_block); +}; + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp new file mode 100644 index 0000000000..d7be43e392 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp @@ -0,0 +1,2131 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "ref_gemm_f32.hpp" +#include "gemm_utils_f32.hpp" +#include "jit_avx512_common_gemm_f32.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define CACHE_LINE_SIZE 64 + +#define STACKSIZE get_size_of_abi_save_regs() +#ifdef _WIN32 +#define STACK_K_CAPACITY 32 +#else +#define STACK_K_CAPACITY 2048 +#endif +#define SIZE 4 +#define OFFSET 128 +#define BASE_SHIFT 2 +#define SECOND_FETCH unroll_n +#define UNROLL_M 48 +#define UNROLL_N 8 + +namespace avx512_common_gemm_f32 { +using namespace gemm_utils; + +struct xbyak_gemm : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm) + + xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, + void *code_ptr = nullptr, + size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + { + using namespace Xbyak; + + enum { ver_avx512_core, ver_avx512_mic } ver = + mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic; + + bool isBeta0 = (beta == 0.0); + bool isBetaN = (!isBeta0 && beta != 1.0); + + // various definitions for convenience + auto ARG_M = abi_param1; + auto ARG_N = abi_param2; + auto K = abi_param3; + auto ARG_ALPHA = abi_param4; +#ifdef _WIN32 + auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; + auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE]; + const auto stackOffset = OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE; + auto A = rsi; + auto LDA = rdi; +#else + auto ARG_A = r8; + auto ARG_LDA = r9; + const auto stackOffset = STACKSIZE; + auto A = ARG_A; + auto LDA = ARG_LDA; +#endif + auto ARG_B = ptr[rsp + 8 + stackOffset]; + auto ARG_LDB = ptr[rsp + 16 + stackOffset]; + auto ARG_BETA = ptr[rsp + 24 + stackOffset]; + auto ARG_C = ptr[rsp + 32 + stackOffset]; + auto ARG_LDC = ptr[rsp + 40 + stackOffset]; + auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; + auto ARG_WS = ptr[rsp + 56 + stackOffset]; + + auto B = r11; + auto LDB = rbx; + auto LDC = r13; + auto LL = rax; + auto AO1 = abi_param2; + auto BO1 = abi_param4; + auto BO2 = rbp; + auto CO1 = r14; + auto CO2 = r15; + auto LDB3 = r10; + auto LDA4 = abi_param1; + auto AA = r12; + auto BIAS1 = abi_param1; + + auto M = qword[rsp + 0]; + auto N = qword[rsp + 8]; + auto FLAG = qword[rsp + 16]; + auto I = qword[rsp + 24]; + auto C = qword[rsp + 32]; + auto BIAS = qword[rsp + 40]; + auto ALPHA = qword[rsp + 48]; + auto BETA = qword[rsp + 64]; + auto ORIG_A = qword[rsp + 80]; + auto ORIG_SP = qword[rsp + 120]; + + auto ZSTRIDE = zmm4; + auto VALPHA = zmm6; + auto VBETA = zmm7; + auto VBIAS1 = zmm1; + auto VBIAS2 = zmm2; + auto VBIAS3 = zmm3; + + auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80; + auto PREFETCHSIZEB = 16; + + Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15, + zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24, + zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 }; + + // Function for packing if needed + auto do_pack = [&](int unroll_m) { + Label pack2, pack3, pack4, pack10; + + mov(BO1, A); + lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); + mov(LL, K); + sar(LL, 2); + jle(pack3, T_NEAR); + align(16); + + L(pack2); + if (!isTransA) { + for (int i = 0; i < 4; i++) { + vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); + if (unroll_m > 16) + vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); + if (unroll_m > 32) + vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); + add(BO1, LDA); + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE] + | k1, + zmm0); + if (unroll_m > 16) + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE] + | k2, + zmm1); + if (unroll_m > 32) + vmovups(ptr[AO1 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE] + | k3, + zmm2); + } + } else { + for (int i = 0; i < 4; i++) { + kmovw(k4, k1); + vgatherqps(ymm5 | k4, + ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO1 + LDA * 8]); + kshiftrw(k4, k1, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + vshuff64x2(zmm0, zmm5, zmm6, 0x44); + + if (unroll_m > 16) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k2); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k2, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + vshuff64x2(zmm1, zmm5, zmm6, 0x44); + } + + if (unroll_m > 32) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k3); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k3, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + vshuff64x2(zmm2, zmm5, zmm6, 0x44); + } + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + if (unroll_m > 16) + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + if (unroll_m > 32) + vmovups(ptr[AO1 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + add(BO1, 4 * SIZE); + } + add(AO1, unroll_m * 4 * SIZE); + + sub(LL, 1); + jg(pack2, T_NEAR); + align(16); + + L(pack3); + mov(LL, K); + and_(LL, 3); + jle(pack10, T_NEAR); + align(16); + + L(pack4); + if (!isTransA) { + vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); + if (unroll_m > 16) + vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); + if (unroll_m > 32) + vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); + add(BO1, LDA); + } else { + kmovw(k4, k1); + vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO1 + LDA * 8]); + kshiftrw(k4, k1, 8); + vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + vshuff64x2(zmm0, zmm5, zmm6, 0x44); + + if (unroll_m > 16) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k2); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k2, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + vshuff64x2(zmm1, zmm5, zmm6, 0x44); + } + + if (unroll_m > 32) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k3); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k3, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + vshuff64x2(zmm2, zmm5, zmm6, 0x44); + } + add(BO1, SIZE); + } + + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + if (unroll_m > 16) + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE], + zmm1 | k2); + if (unroll_m > 32) + vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE], + zmm2 | k3); + + add(AO1, unroll_m * SIZE); + sub(LL, 1); + jg(pack4, T_NEAR); + align(16); + + L(pack10); + }; + + // Function to update C, covering masking and other considerations + auto update = [&](Zmm reg, bool useCO1, int offset, int mask, + bool useScale = false) { + vmulps(reg, reg, VALPHA); + if (!isBeta0) { + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(zmm0, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0, ptr[CO2 + offset * SIZE]); + break; + case 1: + if (useCO1) + vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]); + break; + case 2: + if (useCO1) + vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]); + break; + case 3: + if (useCO1) + vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]); + break; + case 1: + if (useCO1) + vmovups(zmm0 | k1 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k1 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + case 2: + if (useCO1) + vmovups(zmm0 | k2 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k2 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + case 3: + if (useCO1) + vmovups(zmm0 | k3 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k3 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + } + } + if (!isBetaN) { + vaddps(zmm0, reg, zmm0); + } else { + vfmadd132ps(zmm0, reg, VBETA); + } + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3); + break; + } + } + } else { + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg); + else + vmovups(ptr[CO2 + offset * SIZE], reg); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k1); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k2); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k3); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k3); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3); + break; + } + } + } + vpxorq(reg, reg, reg); + }; + + // Loop with unroll_n - 2 FMAs; called by innerkernel + auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) { + for (int i = 2; i < unroll_n; i++) { + if (ver == ver_avx512_core) { + if (!isTransB) { + switch (i) { + case 2: + vbroadcastss( + zmm3, + ptr[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 3: + vbroadcastss( + zmm3, + ptr[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + case 4: + vbroadcastss(zmm3, + ptr[BO2 + (iteration - OFFSET) * SIZE]); + break; + case 5: + vbroadcastss( + zmm3, + ptr[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + break; + case 6: + vbroadcastss( + zmm3, + ptr[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 7: + vbroadcastss( + zmm3, + ptr[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + } + } else { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } + vfmadd231ps(regs[i], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm3, zmm2); + } else { + if (!isTransB) { + switch (i) { + case 2: + vfmadd231ps(regs[i], zmm0, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 3: + vfmadd231ps(regs[i], zmm0, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + case 4: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + break; + case 5: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + break; + case 6: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 7: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + } + } else { + vfmadd231ps( + regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + (i - OFFSET) * SIZE]); + } + } + } + }; + + // Innerkernel; called by kernel + auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect, + bool isCopy, bool doCPrefetch, bool isUnmasked = true) { + for (int i = 0; i < 8; i++) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET) + * SIZE]); + if (unroll_m >= 32) + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET) + * SIZE]); + if (unroll_m >= 48) + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]); + if (unroll_m >= 32) + prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]); + if (unroll_m >= 48) + prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]); + } + + if (!isDirect) { + if (i != 0) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * i + 1 * 16 + - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * i + 1 * 16 + - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * i + 2 * 16 + - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * i + 2 * 16 + - OFFSET) + * SIZE]); + } + } + } + } else { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (ver == ver_avx512_core) { + if (!isTransB) { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + vfmadd231ps(regs[0], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm3, zmm2); + } else { + if (!isTransB) { + vfmadd231ps(regs[0], zmm0, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm1, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm2, + zword_b[BO1 + (i - OFFSET) * SIZE]); + } else { + vfmadd231ps(regs[0], zmm0, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm1, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm2, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + } + } + + if (unroll_n >= i + 1) { + if (!isTransB) { + switch (i) { + case 0: + prefetcht0( + ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 1: + prefetcht0(ptr[BO1 + LDB + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 2: + prefetcht0(ptr[BO1 + LDB * 2 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 3: + prefetcht0(ptr[BO1 + LDB3 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 4: + prefetcht0( + ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 5: + prefetcht0(ptr[BO2 + LDB + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 6: + prefetcht0(ptr[BO2 + LDB * 2 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 7: + prefetcht0(ptr[BO2 + LDB3 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + } + } + } + + if (unroll_n >= 2) { + if (ver == ver_avx512_core) { + if (!isTransB) { + vbroadcastss(zmm3, + ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + vfmadd231ps(regs[1], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm3, zmm2); + } else { + if (!isTransB) { + vfmadd231ps(regs[1], zmm0, + zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm1, + zword_b[BO1 + LDB * 1 + + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm2, + zword_b[BO1 + LDB * 1 + + (i - OFFSET) * SIZE]); + } else { + vfmadd231ps(regs[1], zmm0, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm1, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm2, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + } + } + } + + if (isCopy) { + if (isUnmasked || unroll_m > 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE], + zmm0); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE], + zmm0 | k1); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(ptr[LDA4 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + } + if (i == 7) + sub(LDA4, -unroll_m * 8 * SIZE); + } + fmaloop(unroll_m, unroll_n, i); + + if (i == 1) { + if (doCPrefetch) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 0 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 0 * 16 * SIZE]); + } + } + if (i == 3) { + if (doCPrefetch && unroll_m >= 32) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 1 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 1 * 16 * SIZE]); + } + if (!isTransA) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 0 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 0 * SIZE]); + } + } + if (i == 5) { + if (doCPrefetch) { + if (unroll_m >= 48) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 2 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 2 * 16 * SIZE]); + } + add(CO2, LDC); + } + if (!isTransA) { + if (unroll_m >= 32) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 1 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 1 * SIZE]); + } + } + } + + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + } // end of for loop + + if (!isTransB) { + sub(BO1, -8 * SIZE); + if (unroll_n >= 4) + sub(BO2, -8 * SIZE); + } + if (!isTransA) { + if (unroll_m >= 48) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 2 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 2 * SIZE]); + } + lea(AA, ptr[AA + LDA]); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 8 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 8 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 8 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 8 + 2 * 16 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * 8 * SIZE); + } + + sub(LL, 1); + }; + + // Main kernel; does prefetching and calls innerkernel + // After calculating results in registers, writes back to C matrix by + // calling update + auto kernel = [&](int unroll_m, int unroll_n, bool isDirect, + bool isCopy, bool isUnmasked = true) { + if (!isDirect) { + lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); + } else { + mov(AO1, A); + } + + if (isCopy) { + lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]); + } else { + auto step = ver == ver_avx512_core ? 2 : 4; + lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]); + } + + if (isTransB) { + lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE]); + } + } + } + + Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18; + + mov(LL, K); + sar(LL, 3); + sub(LL, SECOND_FETCH); + jle(kernel13, T_NEAR); + align(16); + + L(kernel12); + innerkernel( + unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked); + jg(kernel12, T_NEAR); + align(16); + + L(kernel13); + lea(CO2, ptr[CO1 + (16 - 1) * SIZE]); + add(LL, unroll_n); + jle(kernel15, T_NEAR); + align(16); + + L(kernel14); + innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked); + jg(kernel14, T_NEAR); + align(16); + + L(kernel15); + mov(LL, K); + and_(LL, 7); + jle(kernel18, T_NEAR); + align(16); + + L(kernel16); + if (isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + for (int i = 0; i < unroll_n; i++) { + if (!isTransB) { + switch (i) { + case 0: + vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); + break; + case 1: + vbroadcastss( + zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + break; + case 2: + vbroadcastss( + zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + break; + case 3: + vbroadcastss( + zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]); + break; + case 4: + vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]); + break; + case 5: + vbroadcastss( + zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + break; + case 6: + vbroadcastss( + zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + break; + case 7: + vbroadcastss( + zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]); + break; + } + } else { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } + vfmadd231ps(regs[i], zmm3, zmm0); + if (unroll_m >= 32) { + vfmadd231ps(regs[i + 8], zmm3, zmm1); + } + if (unroll_m >= 48) { + vfmadd231ps(regs[i + 16], zmm3, zmm2); + } + } + + if (isCopy) { + if (isUnmasked || unroll_m > 16) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0); + } else { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE], + zmm1); + } else { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE], + zmm2); + } else { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 1 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 1 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 1 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 1 + 2 * 16 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + + sub(LL, 1); + jg(kernel16, T_NEAR); + align(16); + + L(kernel18); + vbroadcastss(VALPHA, ALPHA); + + if (isBetaN) { + vbroadcastss(VBETA, BETA); + } + + // Write back the results; all beta cases need to be handled + if (hasBias) { + mov(BIAS1, BIAS); + if (isUnmasked || unroll_m > 16) + vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); + else + vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]); + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) + vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]); + else + vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]); + } + if (unroll_m >= 48) { + if (isUnmasked) + vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]); + else + vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]); + } + } + + for (int i = 0; i < unroll_n; i++) { + bool useScale = i % 2 != 0; + bool useCO1 = i < 2; + if (i == 2) + lea(CO2, ptr[CO1 + LDC * 2]); + if (i == 4 || i == 6) + lea(CO2, ptr[CO2 + LDC * 2]); + if (hasBias) + vaddps(regs[i], VBIAS1, regs[i]); + if (isUnmasked || unroll_m > 16) { + update(regs[i], useCO1, 0, 0, useScale); + } else { + update(regs[i], useCO1, 0, 1, useScale); + } + if (unroll_m >= 32) { + if (hasBias) + vaddps(regs[i + 8], VBIAS2, regs[i + 8]); + if (isUnmasked || unroll_m > 32) { + update(regs[i + 8], useCO1, 16, 0, useScale); + } else { + update(regs[i + 8], useCO1, 16, 2, useScale); + } + } + if (unroll_m >= 48) { + if (hasBias) + vaddps(regs[i + 16], VBIAS3, regs[i + 16]); + if (isUnmasked) { + update(regs[i + 16], useCO1, 32, 0, useScale); + } else { + update(regs[i + 16], useCO1, 32, 3, useScale); + } + } + } + + switch (unroll_n) { + case 1: add(CO1, LDC); break; + case 2: lea(CO1, ptr[CO1 + LDC * 2]); break; + case 3: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 4: lea(CO1, ptr[CO2 + LDC * 2]); break; + case 5: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 6: lea(CO1, ptr[CO2 + LDC * 2]); break; + case 7: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 8: lea(CO1, ptr[CO2 + LDC * 2]); break; + } + + // Compute next address of B + if (!isTransB) { + lea(rax, ptr[K * SIZE]); + switch (unroll_n) { + case 1: + add(BO1, LDB); + add(BO2, LDB); + break; + case 2: + lea(BO1, ptr[BO1 + LDB * 2]); + lea(BO2, ptr[BO2 + LDB * 2]); + break; + case 3: + lea(BO1, ptr[BO1 + LDB3]); + lea(BO2, ptr[BO2 + LDB3]); + break; + case 4: + lea(BO1, ptr[BO1 + LDB * 4]); + lea(BO2, ptr[BO2 + LDB * 4]); + break; + case 5: + lea(BO1, ptr[BO1 + LDB * 4]); + add(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 4]); + add(BO2, LDB); + break; + case 6: + lea(BO1, ptr[BO1 + LDB3 * 2]); + lea(BO2, ptr[BO2 + LDB3 * 2]); + break; + case 7: + lea(BO1, ptr[BO1 + LDB * 8]); + sub(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 8]); + sub(BO2, LDB); + break; + case 8: + lea(BO1, ptr[BO1 + LDB * 8]); + lea(BO2, ptr[BO2 + LDB * 8]); + break; + } + sub(BO1, rax); + sub(BO2, rax); + } else { + mov(rax, LDB); + imul(rax, K); + sub(BO1, rax); + add(BO1, unroll_n * SIZE); + } + }; + + // High-level subroutine; does packing if needed, then splits C matrix. + // Operates on chunks of 48 rows, 8 columns at a time (handling tail + // cases appropriately by doing 32 or 16 rows, and/or with masking, + // and/or fewer columns). + auto subloop = [&](int unroll_m) { + Label l_subloop_20x[8], l_subloop_mask_20x[8]; + Label l_subloop_30x[8], l_subloop_mask_30x[8]; + + Label subloop11, subloop11mask; + Label subloop30, subloop30mask; + Label subloop31, subloop31mask; + Label subloop96; + Label subloop98, subloop98mask; + Label subloop99; + + // Create mask + mov(BO1, rcx); + mov(rcx, M); + sub(rcx, unroll_m - 16); + mov(CO1, 16); + cmp(rcx, 16); + + cmovg(rcx, CO1); + mov(rax, 1); + sal(rax, cl); + sub(rax, 1); + mov(rcx, 0xffff); + + if (unroll_m == 16) { + kmovw(k1, eax); + } else if (unroll_m == 32) { + kmovw(k1, ecx); + kmovw(k2, eax); + } else { + kmovw(k1, ecx); + kmovw(k2, ecx); + kmovw(k3, eax); + } + mov(rcx, BO1); + + and_(rax, 0xffff); + cmp(rax, 0xffff); + jne(subloop96, T_NEAR); + + if (isTransA) { + do_pack(unroll_m); + } + + mov(CO1, C); + add(C, unroll_m * SIZE); + + mov(BO1, B); + if (!isTransB) { + lea(BO2, ptr[B + LDB * 4]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98, T_NEAR); + + mov(AA, ORIG_A); + lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); + L(subloop98); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(l_subloop_20x[1], T_NEAR); + } + align(16); + + if (!isTransA) { + kernel(unroll_m, UNROLL_N, true, true); + } else { + kernel(unroll_m, UNROLL_N, false, false); + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(l_subloop_20x[1], T_NEAR); + align(16); + + L(subloop11); + kernel(unroll_m, UNROLL_N, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_20x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_20x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, false, false); + jmp(subloop99, T_NEAR); + align(16); + } + + if (!isTransA) { + L(subloop30); + cmp(I, UNROLL_N); + jl(l_subloop_30x[1], T_NEAR); + align(16); + + L(subloop31); + kernel(unroll_m, UNROLL_N, true, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_30x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_30x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, true, false); + if (i < 7) + jmp(subloop99, T_NEAR); + align(16); + } + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop96); + if (isTransA) { + do_pack(unroll_m); + } + + mov(CO1, C); + add(C, unroll_m * SIZE); + mov(BO1, B); + if (!isTransB) { + lea(BO2, ptr[B + LDB * 4]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98mask, T_NEAR); + mov(AA, ORIG_A); + lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); + L(subloop98mask); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30mask, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30mask, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(l_subloop_mask_20x[1], T_NEAR); + } + align(16); + + if (!isTransA) { + kernel(unroll_m, UNROLL_N, true, true, false); + } else { + kernel(unroll_m, UNROLL_N, false, false, false); + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(l_subloop_mask_20x[1], T_NEAR); + align(16); + + L(subloop11mask); + kernel(unroll_m, UNROLL_N, false, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11mask, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_mask_20x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_mask_20x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, false, false, false); + jmp(subloop99, T_NEAR); + align(16); + } + + if (!isTransA) { + L(subloop30mask); + cmp(I, UNROLL_N); + jl(l_subloop_mask_30x[1], T_NEAR); + align(16); + + L(subloop31mask); + kernel(unroll_m, UNROLL_N, true, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31mask, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_mask_30x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_mask_30x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, true, false, false); + if (i < 7) + jmp(subloop99, T_NEAR); + align(16); + } + } + + L(subloop99); + // Compute address for A + if (!isTransA) { + add(A, unroll_m * SIZE); + } else { + mov(rax, LDA); + imul(rax, rax, unroll_m); + add(A, rax); + } + + // Compute next address of BIAS + if (hasBias) { + add(BIAS, unroll_m * SIZE); + } + }; + + preamble(); + + Label buffer_in_ws, buffer_allocated; + + // Get the registers + mov(B, ARG_B); + mov(LDB, ARG_LDB); + mov(r15, ARG_BETA); + mov(r12, ARG_C); + if (hasBias) + mov(r10, ARG_BIAS); + mov(LDC, ARG_LDC); + mov(rbp, rsp); + + vmovss(xmm0, ptr[ARG_ALPHA]); + vmovss(xmm1, ptr[r15]); + +#if _WIN32 + mov(A, ARG_A); + mov(LDA, ARG_LDA); +#endif + + cmp(K, STACK_K_CAPACITY); + jg(buffer_in_ws, T_NEAR); + + // Create buffer and align to 4kB page + lea(rax, ptr[K * SIZE]); + imul(rax, rax, 0x30); + add(rax, 256); + sub(rsp, rax); + and_(rsp, -PAGE_4K); + jmp(buffer_allocated, T_NEAR); + + L(buffer_in_ws); + mov(rsp, ARG_WS); + + L(buffer_allocated); + + mov(ORIG_SP, rbp); + mov(M, ARG_M); + mov(N, ARG_N); + mov(C, r12); + if (hasBias) + mov(BIAS, r10); + vmovss(ALPHA, xmm0); + vmovss(BETA, xmm1); + sub(A, -OFFSET * SIZE); + sub(B, -OFFSET * SIZE); + mov(ORIG_A, A); + sal(LDA, BASE_SHIFT); + sal(LDB, BASE_SHIFT); + sal(LDC, BASE_SHIFT); + lea(LDB3, ptr[LDB + LDB * 2]); + + if (isTransA) { + vpbroadcastq(zmm2, LDA); + vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE); + mov(rax, -2); + kmovw(k4, eax); + + for (int i = 0; i < 6; i++) { + vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); + kshiftlw(k4, k4, 1); + } + vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); + } + + // Check A alignment and leading dimension; take copy-based path as + // needed + mov(rax, LDA); + or_(rax, A); + and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f); + mov(FLAG, rax); + + for (int i = 8; i < 16; i++) { + for (int j = 0; j < 3; j++) { + vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j)); + } + } + + Label main0, main1, main2, main999; + + cmp(M, 32); + jle(main0, T_NEAR); + align(16); + + L(main1); + subloop(48); + sub(M, UNROLL_M); + cmp(M, 32); + jg(main1, T_NEAR); + align(16); + + L(main0); + cmp(M, 16); + jle(main2, T_NEAR); + + subloop(32); + jmp(main999, T_NEAR); + align(16); + + L(main2); + cmp(M, 0); + jle(main999, T_NEAR); + subloop(16); + align(16); + + L(main999); + // Restore original stack + mov(rsp, ORIG_SP); + + vzeroupper(); + postamble(); + + ker_ = this->getCode(); + } + + typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws); + + void operator()(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws) const + { + ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); + } + +private: + ker_t ker_; +}; + +const xbyak_gemm *get_xbyak_gemm( + bool isTransA, bool isTransB, float beta, bool hasBias) { + auto beta_idx = [](float beta) { + return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2); + }; + + // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] + static xbyak_gemm *kernel_table[2][2][2][3]; + static std::once_flag initialized; + std::call_once(initialized, [=]{ + for (bool isTransA: {false, true}) + for (bool isTransB: {false, true}) + for (bool hasBias: {false, true}) + for (float beta: {0.0f, 1.0f, 2.0f}) { + // nocopy sgemm with bias for beta != 0.0 is not supported + if (hasBias && beta != 0.0) + continue; + kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = + new xbyak_gemm(isTransA, isTransB, beta, hasBias); + } + }); + + return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; +} + +void sgemm_nocopy_driver(const char *transa, + const char *transb, int m, int n, int k, const float *alpha, + const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, + float *c, dim_t ldc, const float *bias, float *ws) +{ + bool isTransA = (*transa == 'T' || *transa == 't'); + bool isTransB = (*transb == 'T' || *transb == 't'); + + int Bm, sizeM, Bn, sizeN, Bk, sizeK; + + int i, j; + + if ((m <= 0) || (n <= 0)) + return; + + if ((k <= 0) || (alpha[0] == 0.)) { + + if (beta[0] == 0.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] = 0.0; + } else if (beta[0] != 1.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] *= beta[0]; + } + + return; + } + + assert(IMPLICATION(bias != nullptr, *beta == 0.0)); + + // XXX: this happens on every thread... + bool hasBias = (bias != nullptr); + auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); + auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); + auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); + assert(ker_bn && ker_b1 && ker_b0); + + int BM = 4032, BN, BK; + if (mayiuse(avx512_core)) { + BN = isTransA ? 384 : 64; + BK = 384; + } else { + BN = isTransA ? 96 : 64; + BK = isTransB ? 96 : 192; + if (!isTransA && !isTransB) + BK = 128; + } + const float *curA, *curB, *curBias = nullptr; + float *curC; + + for (Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK >= BK * 2) + sizeK = BK; + else { + if (sizeK > BK) + sizeK = (sizeK + 1) / 2; + } + + for (Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM >= BM * 2) + sizeM = BM; + else { + if (sizeM > BM + BM / 2) + sizeM = (sizeM + 1) / 2; + } + + for (Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN >= BN * 2) + sizeN = BN; + else { + if (sizeN > BN + BN / 2) + sizeN = (sizeN + 1) / 2; + } + + if (!isTransA) { + curA = a + Bm + Bk * lda; + } else { + curA = a + Bk + Bm * lda; + } + if (!isTransB) { + curB = b + Bk + Bn * ldb; + } else { + curB = b + Bn + Bk * ldb; + } + curC = c + Bm + (size_t)Bn * ldc; + if (bias != nullptr) { + if (Bk == 0) { + curBias = bias + Bm; + } else { + curBias = nullptr; + } + } + if (Bk == 0) { + if (*beta == 0.0 && bias == nullptr) + (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + else + (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } else { + (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } + } + } + } +} + +} + +mkldnn_status_t jit_avx512_common_gemm_f32( + const char *transa, const char *transb, + const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, + const float *A, const int *p_lda, const float *B, const int *p_ldb, + const float *p_beta, float *C, const int *p_ldc, const float *bias) +{ + using namespace mkldnn::impl::utils; + using namespace avx512_common_gemm_f32; + using namespace gemm_utils; + + if (*p_beta != 0 && bias) + return ref_gemm(transa, transb, p_m, p_n, p_k, + p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + int m = *p_m; + int n = *p_n; + int k = *p_k; + dim_t lda = *p_lda; + dim_t ldb = *p_ldb; + dim_t ldc = *p_ldc; + float beta = *p_beta; + int MB, NB, KB; + + int nthr_m, nthr_n, nthr_k, nthr_mn; + + // Determine threading partitioning + calc_nthr_nocopy_avx512_common( + m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + // May not happen, but just in case + if (nthr < nthr_m * nthr_n * nthr_k) + nthr = nthr_m * nthr_n * nthr_k; + + nthr_mn = nthr_m * nthr_n; + + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; + + float *c_buffers = nullptr; + float *ws_buffers = nullptr; + + if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + + for (int i = 0; i < nthr; i++) + ompstatus[i * CACHE_LINE_SIZE] = 0; + + c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(float), PAGE_4K); + } + + const size_t ws_elems_per_thr = (size_t)k * 48 + 64; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + if (k > STACK_K_CAPACITY) { + ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + } + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int k_from, k_to, myK; + int cbase, ibase; + const float *myA, *myB, *myBias = nullptr; + float *myC = C, myBeta; + float *ws = ws_buffers ? + ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; + dim_t ld = ldc; + + int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + k_from = KB * (ithr_k); + k_to = KB * (ithr_k + 1); + if (k_to > k) + k_to = k; + myK = k_to - k_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; + + if ((myM > 0) && (myN > 0)) { + + if (*transa == 'N' || *transa == 'n') { + myA = &(A[m_from + k_from * lda]); + } else { + myA = &(A[k_from + m_from * lda]); + } + if (*transb == 'N' || *transb == 'n') { + myB = &(B[k_from + n_from * ldb]); + } else { + myB = &(B[n_from + k_from * ldb]); + } + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + if (bias) + myBias = &(bias[m_from]); + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0; + ld = MB; + myBias = nullptr; + } + + sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, + lda, myB, ldb, &myBeta, myC, ld, myBias, ws); + + if (nthr_k > 1 && !sum_later) + ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; + } + + if (nthr_k > 1 && !sum_later) { + + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + /* need to wait until main thread finishes */ + while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) { + }; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { + }; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + + + // handle C summation later + if (nthr_k > 1 && ompstatus[0] == 0) { + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int cbase; + float *myC = C; + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + if (nthr_k > 1) { + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + } + + free(c_buffers); + free(ompstatus_); + free(ws_buffers); + + return mkldnn_success; +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp new file mode 100644 index 0000000000..d581b7fd71 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_GEMM_F32_HPP +#define JIT_AVX512_COMMON_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx512_common_gemm_f32( + const char *transa, const char *transb, const int *M, + const int *N, const int *K, const float *alpha, const float *A, + const int *lda, const float *B, const int *ldb, const float *beta, + float *C, const int *ldc, const float *bias = nullptr); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp new file mode 100644 index 0000000000..60d4220837 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp @@ -0,0 +1,2705 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "ref_gemm_f32.hpp" +#include "gemm_utils_f32.hpp" +#include "jit_avx_gemm_f32.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define CACHE_LINE_SIZE 64 + +#define STACKSIZE get_size_of_abi_save_regs() +#if _WIN32 +#define STACK_K_CAPACITY 128 +#else +#define STACK_K_CAPACITY 8192 +#endif +#define SIZE 4 +#define OFFSET 32 +#define BASE_SHIFT 2 +#define SECOND_FETCH 14 + +namespace avx_gemm_f32 { +using namespace gemm_utils; + +struct xbyak_gemm : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm) + + xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, + void *code_ptr = nullptr, + size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + { + using namespace Xbyak; + + const bool is_avx2 = mayiuse(avx2); + assert(IMPLICATION(!is_avx2, mayiuse(avx))); + + const int UNROLL_M = is_avx2 ? 16 : 8; + const int UNROLL_N = 6; + + bool isBeta0 = (beta == 0.0); + bool isBetaN = (!isBeta0 && beta != 1.0); + + // various definitions for convenience + auto ARG_M = abi_param1; + auto ARG_N = abi_param2; + auto K = abi_param3; + auto ARG_ALPHA = abi_param4; +#ifdef _WIN32 + auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; + auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE]; + const auto stackOffset = OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE; + auto A = rsi; + auto LDA = rdi; +#else + auto ARG_A = r8; + auto ARG_LDA = r9; + const auto stackOffset = STACKSIZE; + auto A = ARG_A; + auto LDA = ARG_LDA; +#endif + auto ARG_B = ptr[rsp + 8 + stackOffset]; + auto ARG_LDB = ptr[rsp + 16 + stackOffset]; + auto ARG_BETA = ptr[rsp + 24 + stackOffset]; + auto ARG_C = ptr[rsp + 32 + stackOffset]; + auto ARG_LDC = ptr[rsp + 40 + stackOffset]; + auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; + auto ARG_WS = ptr[rsp + 56 + stackOffset]; + + auto B = r11; + auto LDB = rbx; + auto LDC = r13; + auto LL = rax; + auto AO1 = abi_param2; + auto BO1 = abi_param4; + auto BO2 = rbp; + auto CO1 = r14; + auto CO2 = r15; + auto LDB3 = r10; + auto LDA4 = abi_param1; + auto AA = r12; + auto BIAS1 = abi_param1; + + auto M = qword[rsp + 0]; + auto N = qword[rsp + 8]; + auto FLAG = qword[rsp + 16]; + auto I = qword[rsp + 24]; + auto C = qword[rsp + 32]; + auto BIAS = qword[rsp + 40]; + auto ALPHA = qword[rsp + 48]; + auto BETA = qword[rsp + 64]; + auto ORIG_A = qword[rsp + 80]; + auto MASK = dword[rsp + 88]; + auto STRIDE = qword[rsp + 120]; + auto ORIG_SP = qword[rsp + 152]; + + auto VALPHA = ymm1; + auto VBETA = ymm2; + auto VMASK = ymm3; + auto VBIAS1 = ymm2; + auto VBIAS2 = ymm4; + + auto PREFETCHSIZEA = 128; + auto PREFETCHSIZEB = (!isTransB) ? -16 : 0; + + // Function for packing if needed + auto do_pack = [&]( + int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + Label pack2, pack3, pack4, pack10; + + int regIdx; + Reg64 reg; + + mov(BO1, A); + lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); + + if (isTransA) { + lea(BO2, ptr[BO1 + LDA * 4]); + lea(CO1, ptr[LDA + LDA * 2]); + vmovupd(ymm7, STRIDE); + } + + mov(LL, K); + sar(LL, 2); + jle(pack3, T_NEAR); + align(16); + + L(pack2); + if (!isTransA) { + for (int i = 0; i < 4; i++) { + regIdx = (i % 2 == 0) ? 4 : 6; + if (isLoad1Unmasked) { + vmovups(Ymm(regIdx), + ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(Ymm(regIdx), VMASK, + ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m > 8) { + if (isLoad2Unmasked) { + vmovups(Ymm(regIdx + 1), + ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(Ymm(regIdx + 1), VMASK, + ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(BO1, LDA); + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + Ymm(regIdx)); + if (unroll_m > 8) { + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + Ymm(regIdx + 1)); + } + } + + } else { + if (isLoad1Unmasked) { + for (int i = 0; i < 2; i++) { + reg = (i % 2 == 0) ? BO1 : BO2; + vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, + ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[reg + LDA * 2]); + vunpcklps(xmm4, xmm0, xmm1); + vunpckhps(xmm5, xmm0, xmm1); + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, + ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm6, xmm0, xmm1); + vunpckhps(xmm2, xmm0, xmm1); + + vunpcklpd(xmm0, xmm4, xmm6); + vunpckhpd(xmm1, xmm4, xmm6); + vmovups(ptr[AO1 + + (unroll_m * 0 + i * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 1 + i * 4 - OFFSET) + * SIZE], + xmm1); + vunpcklpd(xmm0, xmm5, xmm2); + vunpckhpd(xmm1, xmm5, xmm2); + vmovups(ptr[AO1 + + (unroll_m * 2 + i * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 3 + i * 4 - OFFSET) + * SIZE], + xmm1); + } + } else if (is_avx2) { + for (int i = 0; i < 2; i++) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm0, + ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 0 * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 0 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO1 + LDA * 4]); + + for (int i = 0; i < 2; i++) { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 1 * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 1 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + } else { + vxorps(xmm4, xmm4, xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + + auto el_cp = [&](int section, int ld_step) { + RegExp src_addr = section == 0 ? BO1 : BO2; + if (ld_step == 1 || ld_step == 2) + src_addr = src_addr + LDA * ld_step; + else if (ld_step == 3) + src_addr = src_addr + CO1; + src_addr = src_addr - OFFSET * SIZE; + + vmovups(Xmm(ld_step % 2), ptr[src_addr]); + RegExp dst_addr = AO1 + + (ld_step + section * 4 - OFFSET) * SIZE; + for (int off = 0; off < 4; ++off) + pextrd(ptr[dst_addr + unroll_m * off * SIZE], + Xmm(ld_step % 2), off); + }; + + Label l_end; + el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); + el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); + el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); + el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); + el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); + el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); + el_cp(1, 2); + L(l_end); + + lea(BO2, ptr[BO2 + LDA * 4]); + } + + if (unroll_m >= 16) { + assert(is_avx2); + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm4, xmm0, xmm1); + vunpckhps(xmm5, xmm0, xmm1); + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + if (i == 0) + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm6, xmm0, xmm1); + vunpckhps(xmm2, xmm0, xmm1); + + vunpcklpd(xmm0, xmm4, xmm6); + vunpckhpd(xmm1, xmm4, xmm6); + vmovups(ptr[AO1 + + (unroll_m * 0 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 1 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm1); + vunpcklpd(xmm0, xmm5, xmm2); + vunpckhpd(xmm1, xmm5, xmm2); + vmovups(ptr[AO1 + + (unroll_m * 2 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 3 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm1); + } + } else { + for (int i = 0; i < 2; i++) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 2 * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 2 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + + for (int i = 0; i < 2; i++) { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 3 * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 3 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + } + } + add(BO1, (4 * SIZE)); + } + + add(AO1, unroll_m * 4 * SIZE); + sub(LL, 1); + jg(pack2, T_NEAR); + align(16); + + L(pack3); + mov(LL, K); + and_(LL, 3); + jle(pack10, T_NEAR); + align(16); + + L(pack4); + if (!isTransA) { + if (isLoad1Unmasked) { + vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m > 8) { + if (isLoad2Unmasked) { + vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm5, VMASK, + ptr[BO1 + (1 + 8 - OFFSET) * SIZE]); + } + } + add(BO1, LDA); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm4); + if (unroll_m > 8) { + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], + ymm5); + } + } else { + if (isLoad1Unmasked) { + for (int i = 0; i < 2; i++) { + reg = (i % 2 == 0) ? BO1 : BO2; + vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, + ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[reg + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], + xmm1); + + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, + ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], + xmm1); + } else if (is_avx2) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], + xmm1); + + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO2 + LDA * 4]); + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], + xmm1); + } else { + vxorps(xmm4, xmm4, xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + + auto el_cp = [&](int section, int ld_step) { + RegExp src_addr = section == 0 ? BO1 : BO2; + if (ld_step == 1 || ld_step == 2) + src_addr = src_addr + LDA * ld_step; + else if (ld_step == 3) + src_addr = src_addr + CO1; + src_addr = src_addr - OFFSET * SIZE; + + vmovss(xmm1, ptr[src_addr]); + RegExp dst_addr = AO1 + + (ld_step + section * 4 - OFFSET) * SIZE; + movss(ptr[dst_addr], xmm1); + }; + + Label l_end; + el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); + el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); + el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); + el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); + el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); + el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); + el_cp(1, 2); + L(l_end); + + lea(BO2, ptr[BO2 + LDA * 4]); + } + + if (unroll_m >= 16) { + assert(is_avx2); + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), + ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + } else { + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO2 + LDA * 4]); + } + vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE], + xmm1); + + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), + ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + } else { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + } + vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE], + xmm1); + } + add(BO1, SIZE); + } + + add(AO1, unroll_m * SIZE); + sub(LL, 1); + jg(pack4, T_NEAR); + align(16); + + L(pack10); + }; + + // Fused multiply add; may become one or two instructions + auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2, + bool overWrite = false) { + if (useFma) { + if (is_avx2) { + vfmadd231ps(reg2, reg1, reg0); + } else { + assert(UNROLL_M == 8); + auto tent_vreg = overWrite ? reg1 : ymm1; + vmulps(tent_vreg, reg1, reg0); + vaddps(reg2, reg2, tent_vreg); + } + } else { + if (!overWrite) { + vmulps(ymm15, reg1, reg0); + vaddps(reg2, reg2, ymm15); + } else { + vmulps(reg1, reg1, reg0); + vaddps(reg2, reg2, reg1); + } + } + }; + + // Inner kernel with k=8 + auto innerkernel8 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + if (!isDirect) { + prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + + for (int i = 0; i < 8; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (i == 0) { + if (!isTransB) { + prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); + } + } + if (unroll_n >= 2) { + if (!isTransB) { + if (i == 1) { + prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + if (i == 7) { + sub(LDA4, -unroll_m * 8 * SIZE); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (i == 7) { + if (!isTransB) { + sub(BO1, -8 * SIZE); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + if (i == 3) { + prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + if (i == 4) { + prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + if (i == 5) { + prefetcht0( + ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + + if (i == 0) { + if (unroll_m >= 4) { + if (!isDirect) { + prefetcht0( + ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 1 || i == 2) { + if (unroll_m >= 8) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 3 || i == 4 || i == 5 || i == 6) { + if (unroll_m >= 16) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 7) { + if (!isTransB) { + if (unroll_n >= 4) { + sub(BO2, -8 * SIZE); + } + } + if (!isTransA) { + prefetcht2(ptr[AA]); + lea(AA, ptr[AA + LDA]); + } + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps( + ymm0, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } + } + } + } + + if (!isDirect) { + sub(AO1, -unroll_m * 8 * SIZE); + } + sub(LL, 1); + + }; + + // Inner kernel with k=4 + auto innerkernel4 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + if (!isDirect) { + prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + + for (int i = 0; i < 4; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (i == 0) { + if (!isTransB) { + prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); + } + } + if (unroll_n >= 2) { + if (!isTransB) { + if (i == 1) { + prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + if (i == 3) { + sub(LDA4, -unroll_m * 4 * SIZE); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (i == 7) { + if (!isTransB) { + sub(BO1, -8 * SIZE); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + if (i == 3) { + prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + if (i == 4) { + prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + if (i == 5) { + prefetcht0( + ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + + if (i == 0) { + if (unroll_m >= 4) { + if (!isDirect) { + prefetcht0( + ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 1 || i == 2) { + if (unroll_m >= 8) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 3) { + if (!isTransB) { + sub(BO1, -4 * SIZE); + if (unroll_n >= 4) { + sub(BO2, -4 * SIZE); + } + } + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps( + ymm0, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } + } + } + } + + if (!isDirect) { + sub(AO1, -unroll_m * 4 * SIZE); + } + + }; + + // Inner kernel with k=2 + auto innerkernel2 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + for (int i = 0; i < 2; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (unroll_n >= 2) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + + (unroll_m * 1 + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + + (unroll_m * 1 + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + } + + }; + + // Inner kernel with k=1 + auto innerkernel1 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) { + + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg00); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg06); + } + + if (unroll_n >= 2) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg01); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg07); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg02); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg08); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg03); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg09); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg04); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg10); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg05); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg11); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], + ymm1); + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + + }; + + // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as + // appropriate + // After calculating results in registers, writes back to C matrix + auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma, + Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6), + Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9), + Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12), + Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15), + Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6), + Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9), + Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12), + Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) { + if (!isDirect) { + lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); + } else { + mov(AO1, A); + } + + if (isCopy) { + lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]); + } else { + lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]); + } + + if (isTransB) { + lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDB * 2]); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE]); + } + } + } + + for (int i = 4; i < 10; i++) { + vxorps(Ymm(i), Ymm(i), Ymm(i)); + vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6)); + } + + mov(LL, K); + sar(LL, 3); + + Label kernel12, kernel13, kernel14, kernel15; + Label kernel16, kernel17, kernel18; + + sub(LL, SECOND_FETCH); + jle(kernel13, T_NEAR); + align(16); + + L(kernel12); + innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + jg(kernel12, T_NEAR); + align(16); + + L(kernel13); + prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 2) + prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]); + if (unroll_n >= 3) + prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 4) + prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 5) + prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]); + if (unroll_n >= 6) + prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]); + + add(LL, SECOND_FETCH); + jle(kernel15, T_NEAR); + align(16); + + L(kernel14); + innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + jg(kernel14, T_NEAR); + align(16); + + L(kernel15); + test(K, 4); + jle(kernel16, T_NEAR); + innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + + L(kernel16); + test(K, 2); + jle(kernel17, T_NEAR); + innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + align(16); + + L(kernel17); + if (unroll_m == 16) { + if (unroll_n <= 3) { + vaddps(reg00, reg00, reg12); + vaddps(reg01, reg01, reg13); + vaddps(reg02, reg02, reg14); + vaddps(reg06, reg06, reg18); + vaddps(reg07, reg07, reg19); + vaddps(reg08, reg08, reg20); + } + } + + if (unroll_m <= 8) { + vaddps(reg00, reg00, reg12); + vaddps(reg01, reg01, reg13); + vaddps(reg02, reg02, reg14); + vaddps(reg03, reg03, reg15); + vaddps(reg04, reg04, reg16); + vaddps(reg05, reg05, reg17); + } + + test(K, 1); + jle(kernel18, T_NEAR); + innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11); + align(16); + + L(kernel18); + vbroadcastss(VALPHA, ALPHA); + + if (isBetaN) { + vbroadcastss(VBETA, BETA); + } + + // Write back the results; all beta and bias cases need to be + // handled + switch (unroll_n) { + case 1: mov(rax, LDC); break; + case 2: lea(rax, ptr[LDC * 2]); break; + case 3: lea(rax, ptr[LDC + LDC * 2]); break; + case 4: lea(rax, ptr[LDC + LDC * 4]); break; + case 5: + lea(rax, ptr[LDC * 4]); + add(rax, LDC); + break; + case 6: + lea(rax, ptr[LDC + LDC * 2]); + add(rax, rax); + break; + } + + if (hasBias) { + mov(BIAS1, BIAS); + if (isLoad1Unmasked) { + vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); + } else { + vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]); + } + } + + for (int i = 0; i < unroll_n; i++) { + vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA); + if (!isBeta0) { + if (isLoad1Unmasked) { + switch (i) { + case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break; + case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break; + case 2: + vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]); + break; + case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break; + case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break; + case 5: + vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]); + break; + case 1: + vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]); + break; + case 2: + vmaskmovps( + ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]); + break; + case 3: + vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]); + break; + case 4: + vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]); + break; + case 5: + vmaskmovps( + ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]); + break; + } + } + + if (!isBetaN) { + vaddps(Ymm(i + 4), ymm0, Ymm(i + 4)); + } else { + fma(useFma, VBETA, ymm0, Ymm(i + 4), true); + } + } + if (hasBias) { + vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4)); + } + if (isLoad1Unmasked) { + switch (i) { + case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break; + case 1: + vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4)); + break; + case 2: + vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); + break; + case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break; + case 4: + vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4)); + break; + case 5: + vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 1: + vmaskmovps( + ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 2: + vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK, + Ymm(i + 4)); + break; + case 3: + vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 4: + vmaskmovps( + ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 5: + vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK, + Ymm(i + 4)); + break; + } + } + + if (unroll_m >= 16) { + // Re-use ymm4 (VBIAS2) + if (i == 0) { + if (hasBias) { + if (isLoad1Unmasked) { + vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]); + } else { + vmaskmovps( + VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]); + } + } + } + vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA); + if (!isBeta0) { + if (isLoad2Unmasked) { + switch (i) { + case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break; + case 1: + vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]); + break; + case 2: + vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]); + break; + case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break; + case 4: + vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]); + break; + case 5: + vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]); + break; + case 1: + vmaskmovps( + ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]); + break; + case 2: + vmaskmovps(ymm0, VMASK, + ptr[CO1 + LDC * 2 + 8 * SIZE]); + break; + case 3: + vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]); + break; + case 4: + vmaskmovps( + ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]); + break; + case 5: + vmaskmovps(ymm0, VMASK, + ptr[CO2 + LDC * 2 + 8 * SIZE]); + break; + } + } + if (!isBetaN) { + vaddps(Ymm(i + 10), ymm0, Ymm(i + 10)); + } else { + fma(useFma, VBETA, ymm0, Ymm(i + 10), true); + } + } + if (hasBias) { + vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10)); + } + if (isLoad2Unmasked) { + switch (i) { + case 0: + vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10)); + break; + case 1: + vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10)); + break; + case 2: + vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); + break; + case 3: + vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10)); + break; + case 4: + vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10)); + break; + case 5: + vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10)); + break; + case 1: + vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 2: + vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 3: + vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10)); + break; + case 4: + vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 5: + vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + } + } + } + if (i == 2) + add(CO1, rax); + } + if (unroll_n >= 4) { + add(CO2, rax); + } + + // Compute next address of B + if (!isTransB) { + lea(rax, ptr[K * SIZE]); + switch (unroll_n) { + case 1: + add(BO1, LDB); + add(BO2, LDB); + break; + case 2: + lea(BO1, ptr[BO1 + LDB * 2]); + lea(BO2, ptr[BO2 + LDB * 2]); + break; + case 3: + lea(BO1, ptr[BO1 + LDB3]); + lea(BO2, ptr[BO2 + LDB3]); + break; + case 4: + lea(BO1, ptr[BO1 + LDB * 4]); + lea(BO2, ptr[BO2 + LDB * 4]); + break; + case 5: + lea(BO1, ptr[BO1 + LDB * 4]); + add(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 4]); + add(BO2, LDB); + break; + case 6: + lea(BO1, ptr[BO1 + LDB3 * 2]); + lea(BO2, ptr[BO2 + LDB3 * 2]); + break; + } + sub(BO1, rax); + sub(BO2, rax); + } else { + mov(rax, LDB); + imul(rax, K); + sub(BO1, rax); + add(BO1, unroll_n * SIZE); + } + }; + + auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), + Ymm(13), Ymm(14), Ymm(15)); + }; + + auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15)); + }; + + auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy); + }; + + auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy); + }; + + auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), + Ymm(13), Ymm(14), Ymm(15)); + }; + + auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + // High-level subroutine; does packing if needed, then splits C matrix. + // Operates on chunks of 16 rows, 6 columns at a time (handling tail + // cases appropriately). + // Masking is used for tail cases where M is not divisible by 8. + auto subloop = [&]( + int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + if (isTransA) { + do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked); + } + + Label subloop11, subloop11mask; + Label subloop20, subloop21, subloop22, subloop23; + Label subloop24, subloop25; + Label subloop30, subloop31, subloop32, subloop33; + Label subloop34, subloop35; + Label subloop98, subloop98mask; + Label subloop99, subloop99mask; + + mov(CO1, C); + lea(CO2, ptr[CO1 + LDC * 2]); + add(CO2, LDC); + add(C, unroll_m * SIZE); + mov(BO1, B); + if (!isTransB) { + lea(BO2, qword[B + LDB3]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98, T_NEAR); + + mov(AA, ORIG_A); + lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]); + L(subloop98); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(subloop20, T_NEAR); + } + align(16); + + if (!isTransA) { + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, true); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, true); + } + } else { + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(subloop20, T_NEAR); + align(16); + + L(subloop11); + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11, T_NEAR); + align(16); + + L(subloop20); + cmp(I, 1); + jne(subloop21, T_NEAR); + if (unroll_m == 16) { + kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop21); + cmp(I, 2); + jne(subloop22, T_NEAR); + if (unroll_m == 16) { + kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop22); + cmp(I, 3); + jne(subloop23, T_NEAR); + if (unroll_m == 16) { + kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop23); + cmp(I, 4); + jne(subloop24, T_NEAR); + if (unroll_m == 16) { + kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop24); + cmp(I, 5); + jne(subloop99, T_NEAR); + if (unroll_m == 16) { + kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + if (!isTransA) { + L(subloop30); + cmp(I, UNROLL_N); + jl(subloop25, T_NEAR); + align(16); + + L(subloop31); + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, false); + } + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31, T_NEAR); + align(16); + + L(subloop25); + cmp(I, 1); + jne(subloop32, T_NEAR); + if (unroll_m == 16) { + kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop32); + cmp(I, 2); + jne(subloop33, T_NEAR); + if (unroll_m == 16) { + kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop33); + cmp(I, 3); + jne(subloop34, T_NEAR); + if (unroll_m == 16) { + kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop34); + cmp(I, 4); + jne(subloop35, T_NEAR); + if (unroll_m == 16) { + kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop35); + cmp(I, 5); + jne(subloop99, T_NEAR); + if (unroll_m == 16) { + kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + align(16); + } + + L(subloop99); + // Compute address for A + if (!isTransA) { + add(A, unroll_m * SIZE); + } else { + mov(rax, LDA); + imul(rax, rax, unroll_m); + add(A, rax); + } + + // Compute next address of BIAS + if (hasBias) { + add(BIAS, unroll_m * SIZE); + } + }; + + preamble(); + + Label buffer_in_ws, buffer_allocated; + + // Get the registers + mov(B, ARG_B); + mov(LDB, ARG_LDB); + mov(r15, ARG_BETA); + mov(r12, ARG_C); + if (hasBias) + mov(r10, ARG_BIAS); + mov(LDC, ARG_LDC); + mov(rbp, rsp); + + vmovss(xmm0, ptr[ARG_ALPHA]); + vmovss(xmm1, ptr[r15]); + +#if _WIN32 + mov(A, ARG_A); + mov(LDA, ARG_LDA); +#endif + + cmp(K, STACK_K_CAPACITY); + jg(buffer_in_ws, T_NEAR); + + // Create buffer and align to 4kB page + lea(rax, ptr[K * SIZE]); + sal(rax, 4); + add(rax, 256); + sub(rsp, rax); + and_(rsp, -PAGE_4K); + jmp(buffer_allocated, T_NEAR); + + L(buffer_in_ws); + mov(rsp, ARG_WS); + + L(buffer_allocated); + + mov(ORIG_SP, rbp); + mov(M, ARG_M); + mov(N, ARG_N); + mov(C, r12); + if (hasBias) + mov(BIAS, r10); + vmovss(ALPHA, xmm0); + vmovss(BETA, xmm1); + sub(A, -OFFSET * SIZE); + sub(B, -OFFSET * SIZE); + mov(ORIG_A, A); + sal(LDA, BASE_SHIFT); + sal(LDB, BASE_SHIFT); + sal(LDC, BASE_SHIFT); + lea(LDB3, ptr[LDB + LDB * 2]); + + for (int i = 0; i < 8; i++) { + mov(dword[rsp + 88 + i * 4], i); + } + + if (isTransA && is_avx2) { + movq(xmm0, LDA); + vpbroadcastq(ymm1, xmm0); + vinsertf128(ymm0, ymm0, xmm0, 1); + vpermilpd(ymm0, ymm0, 5); + vpaddq(ymm1, ymm1, ymm1); + vperm2f128(ymm1, ymm1, ymm1, 8); + vpaddq(ymm0, ymm0, ymm1); + vmovups(STRIDE, ymm0); + } + + // Check A alignment and leading dimension; take copy-based path as + // needed + mov(rax, LDA); + or_(rax, A); + and_(rax, 0x1f); + mov(FLAG, rax); + + Label main0, main1, main2, main3, main999; + + cmp(M, UNROLL_M); + jl(main0, T_NEAR); + align(16); + + L(main1); + subloop(UNROLL_M, true, true); + sub(M, UNROLL_M); + cmp(M, UNROLL_M); + jge(main1, T_NEAR); + align(16); + + L(main0); + cmp(M, 0); + jle(main999, T_NEAR); + + if (UNROLL_M > 8) { + cmp(M, 8); + jle(main2, T_NEAR); + + sub(M, 8); + vbroadcastss(VMASK, M); + vpcmpgtd(VMASK, VMASK, MASK); + + subloop(16, true, false); + jmp(main999, T_NEAR); + align(16); + + L(main2); + cmp(M, 8); + jne(main3, T_NEAR); + subloop(8, true, true); + jmp(main999, T_NEAR); + } + + align(16); + + L(main3); + vbroadcastss(VMASK, M); + if (is_avx2) { + vpcmpgtd(VMASK, VMASK, MASK); + } else { + auto xmask = Xmm(VMASK.getIdx()); + auto xmm_tmp = xmm4; + + vextractf128(xmm_tmp, VMASK, 1); + vpcmpgtd(xmask, xmask, MASK); + vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4 + vinsertf128(VMASK, VMASK, xmm_tmp, 1); + } + subloop(8, false, false); + align(16); + + L(main999); + // Restore original stack + mov(rsp, ORIG_SP); + + vzeroupper(); + postamble(); + + ker_ = this->getCode(); + } + + typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws); + + void operator()(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws) const + { + ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); + } + +private: + ker_t ker_; +}; + +const xbyak_gemm *get_xbyak_gemm( + bool isTransA, bool isTransB, float beta, bool hasBias) { + auto beta_idx = [](float beta) { + return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2); + }; + + // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] + static xbyak_gemm *kernel_table[2][2][2][3]; + static std::once_flag initialized; + std::call_once(initialized, [=]{ + for (bool isTransA: {false, true}) + for (bool isTransB: {false, true}) + for (bool hasBias: {false, true}) + for (float beta: {0.0f, 1.0f, 2.0f}) { + // nocopy sgemm with bias for beta != 0.0 is not supported + if (hasBias && beta != 0.0) + continue; + kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = + new xbyak_gemm(isTransA, isTransB, beta, hasBias); + } + }); + + return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; +} + +void sgemm_nocopy_driver(const char *transa, + const char *transb, int m, int n, int k, const float *alpha, + const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, + float *c, dim_t ldc, const float *bias, float *ws) +{ + bool isTransA = (*transa == 'T' || *transa == 't'); + bool isTransB = (*transb == 'T' || *transb == 't'); + + int Bm, sizeM, Bn, sizeN, Bk, sizeK; + + int i, j; + + if ((m <= 0) || (n <= 0)) + return; + + if ((k <= 0) || (alpha[0] == 0.)) { + + if (beta[0] == 0.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] = 0.0; + } else if (beta[0] != 1.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] *= beta[0]; + } + + return; + } + + assert(IMPLICATION(bias != nullptr, *beta == 0.0)); + + // XXX: this happens on every thread... + bool hasBias = (bias != nullptr); + auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); + auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); + auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); + assert(ker_bn && ker_b1 && ker_b0); + + int BM = 4032; + int BN = isTransA ? 96 : 48; + int BK = isTransB ? 96 : 256; + const float *curA, *curB, *curBias = nullptr; + float *curC; + + for (Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK >= BK * 2) + sizeK = BK; + else { + if (sizeK > BK) + sizeK = (sizeK + 1) / 2; + } + + for (Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM >= BM * 2) + sizeM = BM; + else { + if (sizeM > BM + BM / 2) + sizeM = (sizeM + 1) / 2; + } + + for (Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN >= BN * 2) + sizeN = BN; + else { + if (sizeN > BN + BN / 2) + sizeN = (sizeN + 1) / 2; + } + + if (!isTransA) { + curA = a + Bm + Bk * lda; + } else { + curA = a + Bk + Bm * lda; + } + if (!isTransB) { + curB = b + Bk + Bn * ldb; + } else { + curB = b + Bn + Bk * ldb; + } + curC = c + Bm + (size_t)Bn * ldc; + if (bias != nullptr) { + if (Bk == 0) { + curBias = bias + Bm; + } else { + curBias = nullptr; + } + } + if (Bk == 0) { + if (*beta == 0.0 && bias == nullptr) + (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + else + (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } else { + (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } + } + } + } +} + +} + +mkldnn_status_t jit_avx_gemm_f32( + const char *transa, const char *transb, + const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, + const float *A, const int *p_lda, const float *B, const int *p_ldb, + const float *p_beta, float *C, const int *p_ldc, const float *bias) +{ + using namespace mkldnn::impl::utils; + using namespace avx_gemm_f32; + using namespace gemm_utils; + + if (*p_beta != 0 && bias) + return ref_gemm(transa, transb, p_m, p_n, p_k, + p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + int m = *p_m; + int n = *p_n; + int k = *p_k; + dim_t lda = *p_lda; + dim_t ldb = *p_ldb; + dim_t ldc = *p_ldc; + float beta = *p_beta; + int MB, NB, KB; + + int nthr_m, nthr_n, nthr_k, nthr_mn; + + // Determine threading partitioning + calc_nthr_nocopy_avx( + m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + // May not happen, but just in case + if (nthr < nthr_m * nthr_n * nthr_k) + nthr = nthr_m * nthr_n * nthr_k; + + nthr_mn = nthr_m * nthr_n; + + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; + + float *c_buffers = nullptr; + float *ws_buffers = nullptr; + + if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + + for (int i = 0; i < nthr; i++) + ompstatus[i * CACHE_LINE_SIZE] = 0; + + c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(float), PAGE_4K); + } + + const size_t ws_elems_per_thr = (size_t)k * 16 + 64; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + if (k > STACK_K_CAPACITY) { + ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + } + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int k_from, k_to, myK; + int cbase, ibase; + const float *myA, *myB, *myBias = nullptr; + float *myC = C, myBeta; + float *ws = ws_buffers ? + ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; + dim_t ld = ldc; + + int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + k_from = KB * (ithr_k); + k_to = KB * (ithr_k + 1); + if (k_to > k) + k_to = k; + myK = k_to - k_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; + + if ((myM > 0) && (myN > 0)) { + + if (*transa == 'N' || *transa == 'n') { + myA = &(A[m_from + k_from * lda]); + } else { + myA = &(A[k_from + m_from * lda]); + } + if (*transb == 'N' || *transb == 'n') { + myB = &(B[k_from + n_from * ldb]); + } else { + myB = &(B[n_from + k_from * ldb]); + } + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + if (bias) + myBias = &(bias[m_from]); + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0; + ld = MB; + myBias = nullptr; + } + + sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, + lda, myB, ldb, &myBeta, myC, ld, myBias, ws); + + if (nthr_k > 1 && !sum_later) + ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; + } + + if (nthr_k > 1 && !sum_later) { + + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + /* need to wait until main thread finishes */ + while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) { + }; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { + }; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + + // handle C summation later + if (nthr_k > 1 && ompstatus[0] == 0) { + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int cbase; + float *myC = C; + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + if (nthr_k > 1) { + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + } + + + free(c_buffers); + free(ompstatus_); + free(ws_buffers); + + return mkldnn_success; +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp new file mode 100644 index 0000000000..aabf520a3c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX_GEMM_F32_HPP +#define JIT_AVX_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx_gemm_f32( + const char *transa, const char *transb, const int *M, + const int *N, const int *K, const float *alpha, const float *A, + const int *lda, const float *B, const int *ldb, const float *beta, + float *C, const int *ldc, const float *bias = nullptr); + + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp new file mode 100644 index 0000000000..5147885a89 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp @@ -0,0 +1,346 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "gemm_utils_f32.hpp" +#include "ref_gemm_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace gemm_utils; + +namespace { + +template +void copy_A( + bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) { + for (int k = 0; k < K; k++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda]; + } + ws += unroll_factor::m; + } +} + +template +void kernel_mxn(int K, const data_t *A, const dim_t lda, + const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc, + const data_t alpha, const data_t beta) { + data_t c[unroll_factor::m * unroll_factor::n] = + { static_cast(0.) }; + for (int k = 0; k < K; k++) { + for (int j = 0; j < unroll_factor::n; j++) { + data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + data_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; + c[i + unroll_factor::m * j] += a * b; + } + } + } + for (int j = 0; j < unroll_factor::n; j++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor::m; i++) { + C[i + j * ldc] = (beta == static_cast(0.)) + ? alpha * c[i + unroll_factor::m * j] + : alpha * c[i + unroll_factor::m * j] + + beta * C[i + j * ldc]; + } + } +} + +template +void block_ker(const int M, const int N, const int K, + const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, + data_t *C, const dim_t ldc, const data_t alpha, const data_t beta, + data_t *ws, bool do_copy) { + int Nu = rnd_dn(N, unroll_factor::n); + int Mu = rnd_dn(M, unroll_factor::m); + for (int i = 0; i < Mu; i += unroll_factor::m) { + for (int j = 0; j < Nu; j += unroll_factor::n) { + const data_t *b = isTransB ? &B[j] : &B[j * ldb]; + const data_t *a = isTransA ? &A[i * lda] : &A[i]; + if (do_copy) { + if (j == 0) { + copy_A(isTransA, K, a, lda, ws); + } + kernel_mxn( + K, ws, unroll_factor::m, b, ldb, + &C[i + j * ldc], ldc, alpha, beta); + } else { + kernel_mxn( + K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); + } + } + } + // tail processing + for (int i = 0; i < M; i++) { + for (int j = Nu; j < N; j++) { + data_t c = beta == static_cast(0.) + ? static_cast(0.) + : beta * C[i + j * ldc]; + for (int p = 0; p < K; p++) { + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } + for (int i = Mu; i < M; i++) { + for (int j = 0; j < Nu; j++) { + data_t c = beta == static_cast(0.) + ? static_cast(0.) + : beta * C[i + j * ldc]; + for (int p = 0; p < K; p++) { + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } +} + +template +void gemm_ithr(const int M, const int N, const int K, const data_t alpha, + const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, + const data_t beta, data_t *C, const dim_t ldc, bool do_copy, + data_t *ws) { + constexpr int BM = gemm_traits::BM; + constexpr int BN = gemm_traits::BN; + constexpr int BK = gemm_traits::BK; + + const data_t *curA; + const data_t *curB; + data_t *curC; + + if ((M <= 0) || (N <= 0)) + return; + + if ((K <= 0) || (alpha == static_cast(0))) { + dim_t MN = N * M; + if (beta == static_cast(0.)) { + for (dim_t j = 0; j < MN; j++) + C[j] = static_cast(0.); + } else if (beta != static_cast(1.)) { + for (dim_t j = 0; j < MN; j++) + C[j] *= beta; + } + return; + } + + for (int Bk = 0; Bk < K; Bk += BK) { + int kb = nstl::min(K - Bk, BK); + for (int Bm = 0; Bm < M; Bm += BM) { + int mb = nstl::min(M - Bm, BM); + for (int Bn = 0; Bn < N; Bn += BN) { + int nb = nstl::min(N - Bn, BN); + curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda; + curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb; + curC = C + Bm + Bn * ldc; + if (Bk == 0) { + block_ker(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, beta, ws, do_copy); + } else { + block_ker(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, static_cast(1.0), + ws, do_copy); + } + } + } + } +} + +} + +template +mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, const int *M_, + const int *N_, const int *K_, const data_t *alpha_, const data_t *A, + const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_, + data_t *C, const int *ldc_, const data_t *bias) { + + bool isTransA = (*transa_ == 'T' || *transa_ == 't'); + bool isTransB = (*transb_ == 'T' || *transb_ == 't'); + const int M = *M_, N = *N_, K = *K_; + const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_; + const data_t alpha = *alpha_, beta = *beta_; + + int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads(); + int nthr_m, nthr_n, nthr_k; + int MB, NB, KB; + // thread balancing over M, N, K & size of blocking dimensions + calc_nthr_nocopy_avx( + M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + data_t *c_buffers = nullptr; + data_t *ws_buffers = nullptr; + if (nthr_k > 1) { + c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(data_t), PAGE_4K); + if (!c_buffers) { + nthr_k = 1; + KB = K; + } + } + + bool do_copy = (NB / unroll_factor::n > 3); + const int nthr_mn = nthr_m * nthr_n; + const int nthr = nthr_mn * nthr_k; + const size_t ws_elems_per_thr = K * unroll_factor::m; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K); + if (do_copy) { + ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K); + if (!ws_buffers) + do_copy = false; + } + + auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N, + int ithr) { + from = NB * (ithr); + to = NB * (ithr + 1); + if (to > N) + to = N; + myN = to - from; + }; + + parallel_nd(nthr, [&](const int ithr) { + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_n = ithr_mn / nthr_m; + int ithr_k = ithr / nthr_mn; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + data_t *ws = do_copy + ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t) + : nullptr; + + int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, + k_from = 0, k_to = 0, myK = 0; + + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(k_from, k_to, myK, KB, K, ithr_k); + + if (myM > 0 && myN > 0) { + data_t myBeta, *myC; + dim_t ld; + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0f; + ld = MB; + } + const data_t *myA = isTransA + ? &(A[k_from + m_from * lda]) + : &(A[m_from + k_from * lda]); + const data_t *myB = isTransB + ? &(B[n_from + k_from * ldb]) + : &(B[k_from + n_from * ldb]); + + if (!isTransA) { + if (!isTransB) { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } + } else { + if (!isTransB) { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } + } + } + }); + + if (nthr_k > 1) { + parallel_nd(nthr, [&](const int ithr) { + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_k = ithr / nthr_mn; + int ithr_n = ithr_mn / nthr_m; + + int n_from = 0, n_to = 0, myN = 0; + int m_from = 0, m_to = 0, myM = 0; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + + // sum matrices partitioned along K dimension + int offset = 0, block = 0; + gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset, + &block); + for (int ik = 1; ik < nthr_k; ++ik) { + data_t *myC = c_buffers + + MB * ((dim_t)NB * (cbase + ik - 1) + offset); + + gemm_utils::sum_two_matrices(myM, block, myC, MB, + &C[m_from + (n_from + offset) * ldc], ldc); + } + }); + } + + if (bias) { + parallel_nd(N, M, [&](int i, int j) { + C[i*ldc + j] += bias[j]; + }); + } + + free(ws_buffers); + free(c_buffers); + + return mkldnn_success; +} + +template mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const float *alpha_, + const float *A, const int *lda_, const float *B, const int *ldb_, + const float *beta_, float *C, const int *ldc_, const float *bias); + +template mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const double *alpha_, + const double *A, const int *lda_, const double *B, const int *ldb_, + const double *beta_, double *C, const int *ldc_, const double *bias); +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp new file mode 100644 index 0000000000..7c90ba6277 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_GEMM_F32_HPP +#define REF_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm(const char *transa, const char *transb, const int *M, + const int *N, const int *K, const data_t *alpha, const data_t *A, + const int *lda, const data_t *B, const int *ldb, const data_t *beta, + data_t *C, const int *ldc, const data_t *bias); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp new file mode 100644 index 0000000000..3dbe07d743 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "mkldnn_traits.hpp" +#include "nstl.hpp" + +#include "jit_generator.hpp" + +#include "gemm.hpp" + +#include "f32/jit_avx512_common_gemm_f32.hpp" +#include "f32/jit_avx_gemm_f32.hpp" +#include "f32/ref_gemm_f32.hpp" + +#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp" +#include "s8x8s32/simple_gemm_s8s8s32.hpp" +#include "s8x8s32/ref_gemm_s8x8s32.hpp" + +#include "os_blas.hpp" + +/* USE_MKL USE_CBLAS effect + * ------- --------- ------ + * yes yes use Intel(R) MKL CBLAS + * yes no use jit + * no yes system-dependent CBLAS + * no no use jit + */ + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t check_gemm_input(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const int *lda, + const int *ldb, const int *ldc, const float *alpha, const float *beta, + const bool with_bias) { + if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta)) + return mkldnn_invalid_arguments; + if (with_bias && *beta != 0) + return mkldnn_unimplemented; + bool consistency = true + && utils::one_of(*transa, 'T', 't', 'N', 'n') + && utils::one_of(*transb, 'T', 't', 'N', 'n') + && *M >= 0 + && *N >= 0 + && *K >= 0; + + if (!consistency) + return mkldnn_invalid_arguments; + bool isTransA = utils::one_of(*transa, 'T', 't'); + bool isTransB = utils::one_of(*transb, 'T', 't'); + int nrowA = isTransA ? *K : *M; + int nrowB = isTransB ? *N : *K; + consistency = true + && *lda >= nstl::max(1, nrowA) + && *ldb >= nstl::max(1, nrowB) + && *ldc >= nstl::max(1, *M); + if (!consistency) + return mkldnn_invalid_arguments; + + return mkldnn_success; +} + +mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc, + const char *transa, const char *transb, const int *M, const int *N, + const int *K, const int *lda, const int *ldb, const int *ldc, + const float *alpha, const float *beta, const bool with_bias) { + if (offsetc == nullptr) + return mkldnn_invalid_arguments; + if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) + return mkldnn_invalid_arguments; + + return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha, + beta, with_bias); +} + +mkldnn_status_t extended_sgemm(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const float *alpha, + const float *A, const int *lda, const float *B, const int *ldb, + const float *beta, float *C, const int *ldc, + const float *bias, const bool force_jit_gemm) { + mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K, + lda, ldb, ldc, alpha, beta, bias != nullptr); + if (status != mkldnn_success) + return status; + +#ifdef USE_CBLAS + if (!force_jit_gemm) { + bool trA = *transa == 't' || *transa == 'T'; + bool trB = *transb == 't' || *transb == 'T'; + CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; + cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, + *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc); + + if (bias) { + // Add bias if necessary (bias is applied to columns of C) + cblas_int incx = 1, incy = 1; + parallel_nd(*N, [&](int n) { + ptrdiff_t offset = (ptrdiff_t)n * (*ldc); + cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy); + }); + } + return mkldnn_success; + } +#endif + + if (mayiuse(avx512_common)) + return jit_avx512_common_gemm_f32(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); + else if (mayiuse(avx)) + return jit_avx_gemm_f32(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); + else + return ref_gemm(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); +} + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, + M, N, K, LDA, LDB, LDC, alpha, beta, false); + if (status != mkldnn_success) + return status; + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + +#if USE_MKL_IGEMM + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + + if (data_traits::data_type == data_type::u8) { + CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans; + CBLAS_OFFSET Cblas_offsetc = + OCisR + ? CblasRowOffset + : OCisC + ? CblasColOffset + : CblasFixOffset; + cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, + *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo, + *beta, C, *LDC, co); + return mkldnn_success; + } else { + assert(data_traits::data_type == data_type::s8); + // TODO CBLAS implementation of gemm_s8s8s32 goes here. + // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo + if (utils::everyone_is(0, *ao, *bo)) { + return simple_gemm_s8s8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, + C, LDC, co); + } else { + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } +#else + cpu_isa_t isa = isa_any; + if (mayiuse(avx512_core_vnni)) { + isa = avx512_core_vnni; + } else if (mayiuse(avx512_core)) { + isa = avx512_core; + } + + if (data_traits::data_type == data_type::u8) { + switch (isa) { + case avx512_core: + case avx512_core_vnni: + return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta, + C, LDC, co); + default: + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } else { + assert(data_traits::data_type == data_type::s8); + // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo + if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni)) + && *ao == 0 && *bo == 0) { + return simple_gemm_s8s8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, + C, LDC, co); + } else { + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } +#endif +} + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const int8_t *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +} +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::cpu; + +mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb, + const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha, + const float *A, const int64_t *lda, const float *B, const int64_t *ldb, + const float *beta, float *C, const int64_t *ldc) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + + return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32); +} + +mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb, + const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, + const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, + const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, + int32_t *C, const int64_t *ldc, const int32_t *co) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); +} + +mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb, + const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, + const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, + const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, + int32_t *C, const int64_t *ldc, const int32_t *co) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + + return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp new file mode 100644 index 0000000000..dc15ff7130 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp @@ -0,0 +1,58 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_HPP +#define GEMM_HPP + +#include "mkldnn_types.h" +#include "os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t extended_sgemm(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const float *alpha, + const float *A, const int *lda, const float *B, const int *ldb, + const float *beta, float *C, const int *ldc, + const float *bias = nullptr, bool force_jit_gemm = false); + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *lda, const int8_t *ao, + const b_dt *B, const int *ldb, const int8_t *bo, const float *beta, + int32_t *c, const int *ldc, const int32_t *co); + +#ifdef USE_CBLAS +#define GEMM_IMPL_STR "gemm:blas" +#else +#define GEMM_IMPL_STR "gemm:jit" +#endif + +#if USE_MKL_IGEMM +#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:blas" +#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:blas" +#else +#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:jit" +#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:jit" +#endif + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp new file mode 100644 index 0000000000..4d34ede0bd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef OS_BLAS_HPP +#define OS_BLAS_HPP + +/** \file + * Common stuff respecting USE_MKL and USE_CBLAS compile flags + * + * USE_MKL USE_CBLAS effect + * ------- --------- ------ + * yes yes normal compile: jit *may* be preferred over Intel(R) MKL CBLAS + * yes no jit calls OK; assert if cblas is ever called + * no yes system-dependent CBLAS + * no no gemm convolution (or other blas) N/A; create stubs + */ + +#if defined(USE_MKL) + +#include "mkl_version.h" + +#define USE_MKL_PACKED_GEMM (INTEL_MKL_VERSION >= 20190001) +#define USE_MKL_IGEMM \ + (INTEL_MKL_VERSION >= 20180000 && __INTEL_MKL_BUILD_DATE >= 20170628) + +#include "mkl_cblas.h" +#if !defined(USE_CBLAS) +#define cblas_sgemm(...) assert(!"CBLAS is unavailable") +#endif + +#else /* defined(USE_MKL) */ + +#define USE_MKL_PACKED_GEMM 0 +#define USE_MKL_IGEMM 0 + +#if defined(_SX) +/* TODO: _SX should also define USE_CBLAS in case the later is available */ +extern "C" { +#include "cblas.h" // CHECK: does SX also have a fortran API sgemm? +} + +#elif defined(USE_CBLAS) +#include "cblas.h" // Maybe a system/cmake cblas works for you? +#else +/* put the stubs to make a code compilable but not workable */ +#define cblas_sgemm(...) assert(!"CBLAS is unavailable") +#endif /* defined(_SX) */ + +#endif /* defined(USE_MKL) */ + +namespace mkldnn { +namespace impl { +namespace cpu { + +#if defined(USE_MKL) && defined(USE_CBLAS) +typedef MKL_INT cblas_int; + +#elif defined(USE_CBLAS) +typedef int cblas_int; + +#if defined(_SX) +/* this cblas.h is peculiar... */ +typedef CBLAS_ORDER CBLAS_LAYOUT; +#endif +#endif + +} +} +} + +#endif /* OS_BLAS_HPP */ + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp new file mode 100644 index 0000000000..dde72f4a17 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp @@ -0,0 +1,206 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_H +#define COMMON_H + +#define GEMM_CODE_SIZE (4096L * 32) + +#define AVX512_UNROLL_M 48 +#define AVX512_UNROLL_N 8 +#define AVX512_UNROLL_K 1 +#define AVX512_BM 9984 +#define AVX512_BN 384 +#define AVX512_BK 768 +#define AVX512_BK_VNNI 1536 +#define AVX512_BK_TRADITIONAL 384 +#define AVX512_BLOCKING_SMALL_K 48 +#define AVX512_BN_SMALL_K 24 + + +#define PAGESIZE 4096 + +#define PADD_BYTESIZE_ONPAGE(x, size) (((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE +#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +enum { + PARTITION_1D_ROW, + PARTITION_1D_COL, + PARTITION_2D_COL_MAJOR, + PARTITION_2D = PARTITION_2D_COL_MAJOR, +}; + +enum { + COPY_NONE, + COPY_A, +}; + +enum { + NO_OFFSET, + FIX_OFFSET, + COL_OFFSET, + ROW_OFFSET, +}; + +// Alias for any dimension related variable. +typedef long long int dim_t; + +typedef struct { + // Interface arguments. + int transa, transb, offsetc; + dim_t m, n, k; + dim_t lda, ldb, ldc; + const int8_t *a; + const uint8_t *b; + int32_t *c; + const float *alpha, *beta; + + int8_t ao, bo; + const int32_t *co; + + // Kernel parameters. + dim_t um, un, uk, bm, bn, bk; + dim_t bn_small_k, bk_traditional, blocking_small_k; + + int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + // Gemv kernels + void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + + void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + // Gemv parameters + int swap; + +} blas_t; + + +class jit_avx512_core_u8_copy_an_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern); + + public: + jit_avx512_core_u8_copy_an_kern(); +}; + +class jit_avx512_core_u8_copy_at_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern); + + public: + jit_avx512_core_u8_copy_at_kern(); +}; + +class jit_avx512_core_u8_copy_bn_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern); + + public: + jit_avx512_core_u8_copy_bn_kern(); +}; + +class jit_avx512_core_u8_copy_bt_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern); + + public: + jit_avx512_core_u8_copy_bt_kern(); +}; + +class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern); + + public: + jit_avx512_core_u8_copy_sum_an_kern(); +}; + +class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern); + + public: + jit_avx512_core_u8_copy_sum_at_kern(); +}; + +class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern); + + public: + jit_avx512_core_u8_copy_sum_bn_kern(); +}; + +class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern); + + public: + jit_avx512_core_u8_copy_sum_bt_kern(); +}; + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp new file mode 100644 index 0000000000..db9dd9ef97 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp @@ -0,0 +1,28 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg); +int gemv_threading_driver(blas_t *arg); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp new file mode 100644 index 0000000000..e4b8e1cde2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp @@ -0,0 +1,1409 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "common.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_gemm_s8u8s32.hpp" +#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" +#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" +#include "gemv.hpp" + +#if defined(_MSC_VER) +#include +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef struct { + int nthrs_m, nthrs_n; + int partition; + int copy_type; +} blas_thread_t; + +static inline void round_to_nearest(int32_t *rounded_val, double fp_val) { + if (fp_val >= 0.) { + fp_val += 0.5; + if (fp_val > INT32_MAX) { + fp_val = INT32_MAX; + } + } else { + fp_val -= 0.5; + if (fp_val < INT32_MIN) { + fp_val = INT32_MIN; + } + } + *rounded_val = (int32_t) fp_val; +} + +static inline void add_results(const dim_t m, const dim_t n, const dim_t k, + const float alpha, const float beta, const int32_t *c_partial_sum, + const dim_t ldcp, int32_t *c_data, const dim_t ldc, + const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao, + const int8_t bo, const int32_t *co, const int offsetc) +{ + for (dim_t j = 0; j < n; ++j) { + for (dim_t i = 0; i < m; ++i) { + int32_t ctemp = c_partial_sum[i + j * ldcp]; + + if (alpha == 1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = ctemp; + } else { + double c_float = (double) beta + * (double) c_data[i + j * ldc]; + c_float += (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } else if (alpha == -1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = -ctemp; + } else { + double c_float = (double) beta + * (double) c_data[i + j * ldc]; + c_float -= (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } else { + if (beta == 0.0f) { + double c_float = alpha * (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } else { + double c_float = alpha * (double) ctemp + + beta * (double) c_data[i + j * ldc]; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } + + if (offsetc == FIX_OFFSET) { + c_data[i + j * ldc] += co[0]; + } else if (offsetc == ROW_OFFSET) { + c_data[i + j * ldc] += co[j]; + } else if (offsetc == COL_OFFSET) { + c_data[i + j * ldc] += co[i]; + } + } + } +} + +// TODO Find a better place for those functions. +static inline dim_t ld_padd(const dim_t x) +{ + return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t))) + * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t)); +} + +void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k, + const int8_t *a, const uint8_t *b, float beta, int32_t *c, + const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum, + const int32_t *co, const int offsetc, const blas_t *arg) +{ + int8_t ao = arg->ao; + int8_t bo = arg->bo; + int32_t co_0 = (offsetc == NO_OFFSET)? 0 : co[0]; + + // Since m and n are limited by blocking, stack overflow may not happen; + // it's up to 32kB +#if !defined(_MSC_VER) + int32_t col_offset[m]; + int32_t row_offset[n]; +#else + int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m); + int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n); +#endif + + int col_req = 0; + int row_req = 0; + + if ((bo != 0) || (offsetc == COL_OFFSET)) + col_req = 1; + if ((ao != 0) || (offsetc == ROW_OFFSET)) + row_req = 1; + + // It needs one of colum or row offsets, but it doesn't need both + if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) { + if ((col_req == 0) && (row_req == 0)) { + if (m <= n) { + col_req = 1; + } else { + row_req = 1; + } + } + } + + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] = 0; + + if (offsetc == COL_OFFSET) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co[i]; + } + + if (bo != 0) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += bo * a_row_sum[i]; + } + } + + if (row_req) { + for (dim_t i = 0; i < n; i++) + row_offset[i] = 0; + + if (offsetc == ROW_OFFSET) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co[i]; + } + + if (ao != 0) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += ao * b_col_sum[i]; + } + } + + if ((offsetc == FIX_OFFSET) && (co_0 != 0)) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co_0; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co_0; + } + } + + if ((ao != 0) && (bo != 0)) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += (int32_t) k * ao * bo; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += (int32_t) k * ao * bo; + } + } + + if (col_req == 0) { + if (row_req == 0) { + if (beta == 0.0) { + arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } else { + if (beta == 0.0) { + arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } + } else { + if (row_req == 0) { + if (beta == 0.0) { + arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } else { + if (beta == 0.0) { + arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } + } +} + +static inline void *align(void *ptr, size_t alignment) +{ + return (void *) utils::rnd_up((uintptr_t) ptr, alignment); +} + +static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k, + const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co, + const blas_t *arg) +{ + dim_t lda = arg->lda; + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + int8_t ao = arg->ao; + int8_t bo = arg->bo; + float alpha = *arg->alpha; + float beta = *arg->beta; + + if (m <= 0 || n <= 0) { + return 0; + } + + // Padding along K dimension. + dim_t k_padd = 0; + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(128LL, k_padd); + } else if (k < 2 * arg->bk) { + k_padd = utils::rnd_up(k / 2, arg->uk); + } else { + k_padd = arg->bk; + } + + // Padding along M dimension. + dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), + arg->um); + + // Padding along N dimension. + dim_t n_padd = 0; + if (k < arg->blocking_small_k) { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), + arg->bn_small_k), arg->un); + } else { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), + arg->un); + } + + // Padding for temporary buffer for C + dim_t ldc_buf = ld_padd(m_padd); + + dim_t strideAm = (arg->transa == 0)? 1 : lda; + dim_t strideAn = (arg->transa != 0)? 1 : lda; + dim_t strideBm = (arg->transb == 0)? 1 : ldb; + dim_t strideBn = (arg->transb != 0)? 1 : ldb; + + size_t a_buf_nelems = m_padd * k_padd; + size_t b_buf_nelems = k_padd * n_padd; + size_t a_row_sum_nelems = m_padd; + size_t b_col_sum_nelems = n_padd; + + size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K + + b_buf_nelems * sizeof(*b) + PAGE_4K + + a_row_sum_nelems * sizeof(*c) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + + bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = (char *) malloc(mem_size, 128); + + if (!mem) { + return -1; + } + + int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); + uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K); + int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); + int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems, + PAGE_4K); + + int32_t *bufferC = NULL; + if (need_c_buffer) { + bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); + } + + float beta_saved = beta; + + int a_block_copied = 0; + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) + sizeM = m_padd; + + dim_t sizeK = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK > k_padd) + sizeK = k_padd; + + // Scale C blocks by beta only for the first time + if (Bk == 0) + beta = beta_saved; + else + beta = 1.0f; + + // Apply C offset when to the last k-block of the partial sum. + int offsetc = NO_OFFSET; + if (Bk + sizeK == k) + offsetc = arg->offsetc; + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) + sizeN = n_padd; + + const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn; + arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL, + NULL, b_col_sum); + + dim_t sizeUM = 0; + for (dim_t Um = 0; Um < sizeM; Um += sizeUM) { + sizeUM = sizeM - Um; + if (sizeUM > arg->um) + sizeUM = arg->um; + + /* + * Use the whole A buffer only if we have multiple B blocks + * for k-dimension, otherwise we are wasting cache to store + * B and C blocks. + */ + dim_t Um_forA = 0; + if (sizeN < n) + Um_forA = Um; + + const int8_t *a_block = a + (Bm + Um) * strideAm + + Bk * strideAn; + if (!a_block_copied) { + arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL, + bufferA + Um_forA * sizeK, NULL, NULL, + a_row_sum + Um_forA); + } + + int32_t *c_block = c + (Bm + Um) + Bn * ldc; + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = Bn; + } else if (offsetc == COL_OFFSET) { + co_stride = Bm + Um; + } + if (need_c_buffer) { + igemm_inner_kernel(sizeUM, sizeN, sizeK, + bufferA + Um_forA * sizeK, bufferB, 0.0f, + bufferC + Um, ldc_buf, a_row_sum + Um_forA, + b_col_sum, NULL, NO_OFFSET, arg); + + // Finish the block adding the necessary alpha, beta + // and offsets. + add_results(sizeUM, sizeN, sizeK, alpha, beta, + bufferC + Um, ldc_buf, c_block, ldc, + a_row_sum + Um_forA, b_col_sum, ao, bo, + co + co_stride, offsetc); + } else { + igemm_inner_kernel(sizeUM, sizeN, sizeK, + bufferA + Um_forA * sizeK, bufferB, beta, + c_block, ldc, a_row_sum + Um_forA, b_col_sum, + co + co_stride, offsetc, arg); + } + } + a_block_copied = 1; + } + a_block_copied = 0; + } + } + + free(mem); + + return 0; +} + +static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n, + const dim_t k, const int8_t *bufferA, const uint8_t *b, + const float beta, int32_t *c, const int offsetc, const int32_t *co, + const int32_t *a_row_sum, const blas_t *arg) +{ + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + int8_t ao = arg->ao; + int8_t bo = arg->bo; + float alpha = *arg->alpha; + + if (m <= 0 || n <= 0) { + return 0; + } + + // Padding along N dimension. + dim_t n_padd = 0; + if (k < arg->blocking_small_k) { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), + arg->bn_small_k), arg->un); + } else { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), + arg->un); + } + + // Padding for temporary buffer for C + dim_t ldc_buf = ld_padd(m); + + dim_t strideBn = (arg->transb != 0)? 1 : ldb; + + size_t b_buf_nelems = k * n_padd; + size_t b_col_sum_nelems = n_padd; + + size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + + bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = (char *) malloc(mem_size, 128); + + if (!mem) { + return -1; + } + + uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K); + int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); + + int32_t *bufferC = NULL; + if (need_c_buffer) { + bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); + } + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) + sizeN = n_padd; + + // Implement the kernel here. + const uint8_t *b_block = b + Bn * strideBn; + arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL, + b_col_sum); + + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = Bn; + } else if (offsetc == COL_OFFSET) { + co_stride = 0; + } + int32_t *c_block = c + Bn * ldc; + if (need_c_buffer) { + igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC, + ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg); + + // Finish the block adding the necessary alpha, beta and offsets. + add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block, + ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride, + offsetc); + } else { + igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block, + ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg); + } + } + + free(mem); + + return 0; + +} + +#define N2D_MAX_AVX512 384 +#define M2D_MIN_AVX512 384 +#define VECLEN 16 +#define NCONS 1 +static inline void set_thread_opts_avx512(int *p_nthrs, + blas_thread_t *thread_info, const blas_t *arg) +{ + int nthrs = *p_nthrs; + dim_t m = arg->m; + dim_t n = arg->n; + + thread_info->nthrs_m = 0; + thread_info->nthrs_n = 0; + thread_info->copy_type = COPY_NONE; // By default don't do parallel copy. + + int condition_2D_bsrc = -1; + if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) { + condition_2D_bsrc = 1; + } else { + condition_2D_bsrc = 0; + } + + int condition_1D_copya = 0; + if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) { + condition_2D_bsrc = 0; + condition_1D_copya = 1; + } + + // If offset is non-zero, we need to keep 1D_copya to reduce update overhead + if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0 + || arg->offsetc != FIX_OFFSET) { + condition_2D_bsrc = 0; + condition_1D_copya = 1; + } + + if (condition_2D_bsrc == 1) { + int nthrs_m = 1; + int nthrs_n = nthrs; + + while ((nthrs_n % 2 == 0) && + (n / nthrs > N2D_MAX_AVX512 || + n / nthrs_n <= N2D_MAX_AVX512 / 2) && + (m / nthrs_m >= 2 * M2D_MIN_AVX512) && + (nthrs_m < 4)) { + nthrs_m *= 2; + nthrs_n /= 2; + } + + thread_info->nthrs_m = nthrs_m; + thread_info->nthrs_n = nthrs_n; + thread_info->partition = PARTITION_2D; + + // Reset the total number of threads that will be used. + *p_nthrs = nthrs_m * nthrs_n; + + } else if (condition_1D_copya && mkldnn_thr_syncable()) { + // Use parallel copy A algorithm + thread_info->copy_type = COPY_A; + thread_info->partition = PARTITION_1D_COL; + } else { + if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) { + thread_info->partition = PARTITION_1D_ROW; + } else { + thread_info->partition = PARTITION_1D_COL; + } + } +} +#undef N2D_MAX_AVX512 +#undef M2D_MIN_AVX512 +#undef VECLEN +#undef NCONS + +static inline void partition_1d(const int ithr, const int nthrs, const dim_t n, + dim_t *t_offset, dim_t *t_block) +{ + dim_t band = n / nthrs; + + dim_t tail = n - (nthrs - 1) * band; + if (tail > (band + 1)) + band++; + tail = n - (nthrs - 1) * band; + + if (ithr < (nthrs - 1)) + *t_block = band; + else + *t_block = tail; + + *t_offset = ithr * band; + + if (*t_offset >= n) { + *t_block = 0; + *t_offset = 0; + } else if ((*t_offset + *t_block) > n) { + *t_block = n - *t_offset; + } +} + +static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i, + const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m, + const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp, + dim_t *p_n_band) +{ + dim_t m_disp = 0, n_disp = 0; + dim_t m_band = 0, n_band = 0; + + int mdiv = nthrs_m; + int ndiv = nthrs_n; + + dim_t m_bandt = m / mdiv; /* size per thread */ + dim_t n_bandt = n / ndiv; /* size per thread */ + int firstmgroup = mdiv - 1; + int firstngroup = ndiv - 1; + dim_t firstmval = m_bandt; + dim_t firstnval = n_bandt; + + int mthr_used = mdiv; + if (m - (mdiv - 1) * m_bandt > m_bandt + 1) { + if (m - (mdiv - 1) * m_bandt > mdiv) + ++m_bandt; + + firstmval = m_bandt + 1; + mthr_used = (int) (m / firstmval); + + if (mthr_used * firstmval < m) + ++mthr_used; + + firstmgroup = mthr_used - 1; + } + + int nthr_used = ndiv; + if (n - (ndiv - 1) * n_bandt > n_bandt + 1) { + firstnval = n_bandt + 1; + nthr_used = (int) (n / firstnval); + + if (nthr_used * firstnval < n) + ++nthr_used; + + firstngroup = nthr_used - 1; + } + + *nthrs = mthr_used * nthr_used; + + if (ithr < *nthrs) { + if (ithr_i < firstmgroup) { + m_band = firstmval; + m_disp = ithr_i * firstmval; + } else if (ithr_i <= mthr_used - 2) { + m_band = m_bandt; + m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt; + } else { + m_disp = firstmgroup * firstmval + + (mthr_used - 1 - firstmgroup) * m_bandt; + m_band = nstl::max(0LL, m - m_disp); + } + + if (ithr_j < firstngroup) { + n_band = firstnval; + n_disp = ithr_j * firstnval; + } else if (ithr_j <= nthr_used - 2) { + n_band = n_bandt; + n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt; + } else { + n_disp = firstngroup * firstnval + + (nthr_used - 1 - firstngroup) * n_bandt; + n_band = nstl::max(0LL, n - n_disp); + } + m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL); + n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL); + } + + if (ithr < *nthrs) { + *p_m_disp = m_disp; + *p_n_disp = n_disp; + *p_m_band = m_band; + *p_n_band = n_band; + } else { + *p_m_disp = 0; + *p_n_disp = 0; + *p_m_band = 0; + *p_n_band = 0; + } + + return; +} + +static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m, + dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c, + const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg) +{ + dim_t strideAm = (arg->transa == 0)? 1 : arg->lda; + dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb; + int offsetc = arg->offsetc; + + switch (thread_info->partition) { + case PARTITION_1D_ROW: + { + dim_t offset = 0; + dim_t block = 0; + partition_1d(ithr, *nthrs, arg->m, &offset, &block); + + *m = block; + *n = arg->n; + *k = arg->k; + + // Set matrix A. + *a = arg->a + offset * strideAm; + + // Set matrix B. + *b = arg->b; + + // Set matrix C. + *c = arg->c + offset; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = 0; + } else if (offsetc == COL_OFFSET) { + co_stride = offset; + } + *co = arg->co + co_stride; + break; + } + + case PARTITION_1D_COL: + { + dim_t offset = 0; + dim_t block = 0; + partition_1d(ithr, *nthrs, arg->n, &offset, &block); + + *m = arg->m; + *n = block; + *k = arg->k; + + // Set matrix A. + *a = arg->a; + + // Set matrix B. + *b = arg->b + offset * strideBn; + + // Set matrix C. + *c = arg->c + offset * arg->ldc; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = offset; + } else if (offsetc == COL_OFFSET) { + co_stride = 0; + } + *co = arg->co + co_stride; + break; + } + + case PARTITION_2D_COL_MAJOR: + { + int nthrs_m = thread_info->nthrs_m; + int nthrs_n = thread_info->nthrs_n; + int ithr_i = ithr % nthrs_m; + int ithr_j = ithr / nthrs_m; + + dim_t m_disp = 0; + dim_t m_band = 0; + dim_t n_disp = 0; + dim_t n_band = 0; + + partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n, + arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band); + + *m = m_band; + *n = n_band; + *k = arg->k; + + // Set matrix A. + *a = arg->a + m_disp * strideAm; + + // Set matrix B. + *b = arg->b + n_disp * strideBn; + + // Set matrix C. + *c = arg->c + m_disp + n_disp * arg->ldc; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = n_disp; + } else if (offsetc == COL_OFFSET) { + co_stride = m_disp; + } + *co = arg->co + co_stride; + break; + } + } +} + +#define MULTIPLIER 10 +static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m, + const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b, + int32_t *c, const int32_t *co, const blas_t *arg, + char **p_shared_mem) +{ + const dim_t lda = arg->lda; + const dim_t ldb = arg->ldb; + const dim_t strideAm = (arg->transa == 0)? 1 : lda; + const dim_t strideAn = (arg->transa != 0)? 1 : lda; + const dim_t strideBm = (arg->transb == 0)? 1 : ldb; + + // Padding along M dimension. + dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), + arg->um); + + // Padding along K dimension. + dim_t k_padd = 0; + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(128LL, k_padd); + } else if (k < 2 * arg->bk) { + k_padd = utils::rnd_up(k / 2, arg->uk); + } else { + k_padd = arg->bk; + } + + m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs; + if (m_padd > m) { + m_padd = utils::rnd_up(m, arg->um); + } + + size_t a_buf_nelems = m_padd * k_padd; + + // Allocate shared memory for A and its row sum buffers in master thread. + if (ithr == 0) { // If thread master + size_t a_row_sum_nelems = m_padd; + + size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K) + + a_row_sum_nelems * sizeof(*c) + PAGE_4K; + + *p_shared_mem = (char *) malloc(mem_size, 128); + + } + mkldnn_thr_barrier(); + + char *mem = *p_shared_mem; + int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); + int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K); + + if (!mem) { + return -1; + } + + int result = 0; // Return status + + dim_t sizeK = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK > k_padd) + sizeK = k_padd; + + // Scale C blocks by beta only for the first term of partial sum. + float beta = 1.0f; + if (Bk == 0) + beta = *(arg->beta); + + // Apply C offset for the last k-block of the partial sum. + int offsetc = NO_OFFSET; + if (Bk + sizeK == k) + offsetc = arg->offsetc; + + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) + sizeM = m_padd; + + if (ithr < nthrs) { + dim_t band = (sizeM + nthrs - 1) / nthrs; + band = utils::rnd_up(band, arg->um); + + dim_t offset = band * ithr; + + // If offset is too large don't use that thread for copying. + if (offset >= sizeM) { + offset = 0; + band = 0; + } + + // Handle the tail of the copy. + if (offset + band > sizeM) { + band = sizeM - offset; + } + + if (band > 0) { + const int8_t *a_block = a + (Bm + offset) * strideAm + + Bk * strideAn; + arg->copyA(&sizeK, &band, a_block, &lda, NULL, + bufferA + offset * sizeK, NULL, NULL, + a_row_sum + offset); + } + } + mkldnn_thr_barrier(); // Wait for finishing parallel copy. + + const uint8_t *b_block = b + Bk * strideBm; + int32_t *c_block = c + Bm; + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = 0; + } else if (offsetc == COL_OFFSET) { + co_stride = Bm; + } + + result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK, + bufferA, b_block, beta, c_block, offsetc, co + co_stride, + a_row_sum, arg); + + mkldnn_thr_barrier(); // Wait for kernel computations to finish. + } + } + + // Free memory allocated in master thread + if (ithr == 0) { + free(mem); + } + + return result; +} +#undef MULTIPLIER + +static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k, + double fp_per_cycle, int *nthrs) +{ + double omp_overhead_small_core = 3.0e+3; + double omp_intercept_big_core = 4.0e+3; + double omp_slope_big_core = 5.0e+2; + + double gemm_cycles = 8.0 * m * n * k / fp_per_cycle; + + int i = *nthrs; + + // Use a different model for omp overheads if nthrs is <= 4 + if (*nthrs <= 4 && omp_overhead_small_core > 0) { + double omp_cycles = omp_overhead_small_core; + if (gemm_cycles < omp_cycles) { + *nthrs = 1; + return; + } else { + while (i > 1) { + if (omp_cycles * i < gemm_cycles * (i - 1)) break; + --i; + } + } + } else { + if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) { + *nthrs = 1; + return; + } + + // adaptive decrement to march faster· + while (i > 1) { + double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core; + if (omp_cycles * i < gemm_cycles * (i - 1)) + break; + + if (i < 10) + i -= 2; + else if (i < 30) + i -= 4; + else + i -= 8; + } + } + + if (i < 1) + i = 1; + + *nthrs = i; +} + +#define CACHE_LINE_SIZE 64 +static int gemm_threading_driver(blas_t *arg) +{ + if ((arg->m <= 0) || (arg->n <= 0)) + return mkldnn_success; + + if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) { + return mkldnn_success; + } + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr); + + if (nthr == 1) { + return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b, + arg->c, arg->co, arg); + } + + int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE, + PAGE_4K); + + if (!results) { + return -1; + } + + for (int i = 0; i < nthr; i++) { + results[i * CACHE_LINE_SIZE] = 0; // Initialize to success + } + + char *shared_mem = NULL; + + parallel(nthr, [&](const int ithr, const int nthr) { + int nthrs = nthr; + if (nthrs == 1) { + results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, + arg->b, arg->c, arg->co, arg); + } else { + blas_thread_t thread_info; + set_thread_opts_avx512(&nthrs, &thread_info, arg); + + const int8_t *a = NULL; + const uint8_t *b = NULL; + int32_t *c = NULL; + const int32_t *co = NULL; + dim_t m = -1; + dim_t n = -1; + dim_t k = -1; + decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co, + &thread_info, arg); + + if (ithr < nthrs) { + switch (thread_info.copy_type) { + case COPY_A: + results[ithr * CACHE_LINE_SIZE] = + parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg, + &shared_mem); + break; + + default: + case COPY_NONE: + results[ithr * CACHE_LINE_SIZE] = + gemm_kernel_driver(m, n, k, a, b, c, co, arg); + break; + } + } + } + }); + + int result = 0; // Initialize to success + for (int i = 0; i < nthr; i++) { + if (results[i] != 0) { + result = results[i * CACHE_LINE_SIZE]; + break; + } + } + + free(results); + + return result; +} +#undef CACHE_LINE_SIZE + +static jit_avx512_core_u8_copy_an_kern *copy_an; +static jit_avx512_core_u8_copy_at_kern *copy_at; +static jit_avx512_core_u8_copy_bn_kern *copy_bn; +static jit_avx512_core_u8_copy_bt_kern *copy_bt; +static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an; +static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at; +static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn; +static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt; +static jit_avx512_core_gemm_s8u8s32_kern *kernel; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_r; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_c; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c; +static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel; +static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel; + +static void jit_init(blas_t *arg) +{ + static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + + static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + if (mayiuse(avx512_core_vnni)) { + arg->um = AVX512_UNROLL_M; + arg->un = AVX512_UNROLL_N; + arg->uk = AVX512_UNROLL_K; + arg->bm = AVX512_BM; + arg->bn = AVX512_BN; + arg->bk = AVX512_BK_VNNI; + + arg->bk_traditional = AVX512_BK_TRADITIONAL; + arg->bn_small_k = AVX512_BN_SMALL_K; + arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; + } else { + arg->um = AVX512_UNROLL_M; + arg->un = AVX512_UNROLL_N; + arg->uk = AVX512_UNROLL_K; + arg->bm = AVX512_BM; + arg->bn = AVX512_BN; + arg->bk = AVX512_BK; + + arg->bk_traditional = AVX512_BK_TRADITIONAL; + arg->bn_small_k = AVX512_BN_SMALL_K; + arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; + } + + static std::once_flag initialized; + std::call_once(initialized, []{ + + copy_an = new jit_avx512_core_u8_copy_an_kern(); + copy_at = new jit_avx512_core_u8_copy_at_kern(); + copy_bn = new jit_avx512_core_u8_copy_bn_kern(); + copy_bt = new jit_avx512_core_u8_copy_bt_kern(); + + copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern(); + copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern(); + copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern(); + copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern(); + + kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false); + kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true); + kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true); + kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false); + kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false); + kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true); + kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true); + kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false); + + gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); + gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); + + + copyAn = copy_an->getCode(); + + copyAt = copy_at->getCode(); + + copyBn = copy_bn->getCode(); + + copyBt = copy_bt->getCode(); + + copySumAn = copy_sum_an->getCode(); + + copySumAt = copy_sum_at->getCode(); + + copySumBn = copy_sum_bn->getCode(); + + copySumBt = copy_sum_bt->getCode(); + + kern = kernel->getCode(); + + kern_b = kernel_b->getCode(); + + kern_r = kernel_r->getCode(); + + kern_c = kernel_c->getCode(); + + kern_b0 = kernel_b0->getCode(); + + kern_b0_b = kernel_b0_b->getCode(); + + kern_b0_r = kernel_b0_r->getCode(); + + kern_b0_c = kernel_b0_c->getCode(); + + gemv_s8u8s32_kern = + gemv_s8u8s32_kernel -> generate + (mayiuse(avx512_core_vnni)); + gemv_u8s8s32_kern = + gemv_u8s8s32_kernel -> generate + (mayiuse(avx512_core_vnni)); + }); + + if (arg->bo == 0) { // No need to compute A row sum if bo is zero + if (arg->transa == 0) { + arg->copyA = copyAn; + } else { + arg->copyA = copyAt; + } + } else { + if (arg->transa == 0) { + arg->copyA = copySumAn; + } else { + arg->copyA = copySumAt; + } + } + + if (arg->ao == 0) { // No need to compute B column sum if ao is zero + if (arg->transb == 0) { + arg->copyB = copyBn; + } else { + arg->copyB = copyBt; + } + } else { + if (arg->transb == 0) { + arg->copyB = copySumBn; + } else { + arg->copyB = copySumBt; + } + } + + arg->kernel = kern; + arg->kernel_b = kern_b; + arg->kernel_r = kern_r; + arg->kernel_c = kern_c; + arg->kernel_b0 = kern_b0; + arg->kernel_b0_b = kern_b0_b; + arg->kernel_b0_r = kern_b0_r; + arg->kernel_b0_c = kern_b0_c; + arg -> gemv_s8u8s32_kernel = gemv_s8u8s32_kern; + arg -> gemv_u8s8s32_kernel = gemv_u8s8s32_kern; +} + +mkldnn_status_t jit_avx512_core_gemm_s8u8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const uint8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc) +{ + char transa = *transA; + char transb = *transB; + char offsetc = *offsetC; + + blas_t args; + + // Initialize blas structure + args.m = *m; + args.n = *n; + args.k = *k; + args.alpha = alpha; + args.a = a; + args.lda = *lda; + args.b = b; + args.ldb = *ldb; + args.beta = beta; + args.c = c; + args.ldc = *ldc; + args.transa = (transa == 'N' || transa == 'n') ? 0 : 1; + args.transb = (transb == 'N' || transb == 'n') ? 0 : 1; + args.um = 0; + args.un = 0; + args.bm = 0; + args.bn = 0; + args.bk = 0; + args.copyA = NULL; + args.copyB = NULL; + args.kernel = NULL; + args.kernel_b0 = NULL; + args.ao = *oa; + args.bo = *ob; + args.co = oc; + + if (offsetc == 'F' || offsetc == 'f') { + args.offsetc = FIX_OFFSET; + } else if (offsetc == 'R' || offsetc == 'r') { + args.offsetc = ROW_OFFSET; + } else { // offsetc == 'C' || offsetc == 'c' + args.offsetc = COL_OFFSET; + } + + jit_init(&args); + int result = gemm_threading_driver(&args); + + return (result < 0) ? mkldnn_out_of_memory : mkldnn_success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp new file mode 100644 index 0000000000..b2e2902a12 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_GEMM_S8U8S32_HPP +#define JIT_AVX512_CORE_GEMM_S8U8S32_HPP + +#include +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx512_core_gemm_s8u8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const uint8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc); + +} +} +} + +#endif // JIT_AVX512_CORE_GEMM_S8U8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp new file mode 100644 index 0000000000..57554a1852 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp @@ -0,0 +1,539 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" + + +#ifdef _WIN32 +static const bool is_windows = 1; +#else +static const bool is_windows = 0; +#endif + + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + + + + +// Convert between vector register lengths. +static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); } +static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); } + +// Load from or store to C. +void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst, + const Xbyak::Address &src, int nelems) +{ + switch (nelems) { + default: vmovups(dst, src); break; + case 8: vmovups(make_ymm(dst), src); break; + case 4: vmovups(make_xmm(dst), src); break; + case 2: vmovlps(make_xmm(dst), src); break; + case 1: vmovss(make_xmm(dst), src); break; + } +} +void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst, + const Xbyak::Xmm &src, int nelems) +{ + switch (nelems) { + default: vmovups(dst, src); break; + case 8: vmovups(dst, make_ymm(src)); break; + case 4: vmovups(dst, make_xmm(src)); break; + case 2: vmovsd(dst, make_xmm(src)); break; + case 1: vmovss(dst, make_xmm(src)); break; + } +} + +// Perform length-4 dot product accumulations of unsigned and signed bytes +// in parallel. +// Use vpdpbusd if VNNI available, otherwise emulate. +void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst, + const Xmm &src1, const Xmm &src2) +{ + if (vnni) + vpdpbusd(dst, src1, src2); + else { + vpmaddubsw(dp_scratch, src1, src2); + vpmaddwd(dp_scratch, ones, dp_scratch); + vpaddd(dst, dst, dp_scratch); + } +} + +// Inner kernel. +void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n, + bool cfetch) +{ + int um_vecs = (unroll_m + 15) >> 4; + Label label_kernel_loop; + + L_aligned(label_kernel_loop); { + for (int h = 0; h < 4; h++) { + for (int j = 0; j < unroll_n; j++) { + const Zmm b = b_regs[j & 1]; + + vpbroadcastd(b, ptr[BO + isize * + (2 * j + 2 * h * unroll_n - offset_b)]); + dot_product(c_regs[0][j], b, a_regs[0]); + + if (j == 1 && !(h & 1)) + prefetch_b(ptr[BO + isize * (prefetch_size_b + + 2 * h * unroll_n - offset_b)]); + else if (j % 3 == 0) + prefetch_a(ptr[AO + isize * (prefetch_size_a + + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]); + + for (int i = 1; i < um_vecs; i++) + dot_product(c_regs[i][j], b, a_regs[i]); + + if (cfetch && (j == std::min(1, unroll_n - 1))) { + if (h == 3) + lea(CO2, ptr[CO2 + LDC]); + else if (h < um_vecs) + prefetch_c(ptr[CO2 + (16 * h * size)]); + } + + if (h == 3 && j == std::min(3, unroll_n - 1)) + lea(AA, ptr[AA + (32 * isize)]); + } + + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * + (32 * i + 2 * (h + 1) * unroll_m - offset_a)]); + + if (h == 2) + prefetch_x(ptr[AA - (offset_a * isize)]); + } + + add(AO, 8 * isize * unroll_m); + add(BO, 8 * isize * unroll_n); + sub(LoopCount, 1); + jg(label_kernel_loop, T_NEAR); + } +} + +// k remainder loop for kernel. +void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m, + int unroll_n, int unroll_k, int bwidth) +{ + if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) + || (unroll_m < 0) || (unroll_n < 0)) + return; + + int um_vecs = (unroll_m + 15) >> 4; + + for (int h = 0; h < unroll_k; h++) { + for (int j = 0; j < unroll_n; j++) { + Zmm b = b_regs[j & 1]; + auto b_src = ptr[BO + (-isize * offset_b + + bwidth * (j + h * unroll_n))]; + + switch (bwidth) { + case 4: + vpbroadcastd(b, b_src); + break; + case 2: + vpbroadcastw(b, b_src); + break; + case 1: + vpbroadcastb(b, b_src); + break; + } + for (int i = 0; i < um_vecs; i++) + dot_product(c_regs[i][j], b, a_regs[i]); + } + + if (unroll_k > 1) { + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * (32 * i + + (h + 1) * 2 * unroll_m - offset_a)]); + } + } + + add(AO, unroll_k * unroll_m * bwidth); + add(BO, unroll_k * unroll_n * bwidth); +} + +// Inner loop. +void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n) +{ + if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) + || (unroll_m < 0) || (unroll_n < 0)) + return; + + int um_vecs = (unroll_m + 15) >> 4; + int stage1 = unroll_n, stage2 = unroll_n; + + Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2; + Label label_k_main_loop_3, label_kernel_loop_3; + Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2; + Label label_k_rem_1, label_update_begin; + + mov(AO, A); + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]); + + mov(LoopCount, K); + sar(LoopCount, 4); + jle(label_k_remainder_loop_begin, T_NEAR); + + // Main k loops, broken into three parts to time C prefetching. + sub(LoopCount, stage1 + stage2); + jle(label_k_main_loop_2, T_NEAR); + + kernel_loop(unroll_m, unroll_n, false); + + L_aligned(label_k_main_loop_2); + lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); + add(LoopCount, stage1); + jle(label_k_main_loop_3, T_NEAR); + + kernel_loop(unroll_m, unroll_n, true); + + L_aligned(label_k_main_loop_3); + lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); + add(LoopCount, stage2); + jle(label_k_remainder_loop_begin, T_NEAR); + + kernel_loop(unroll_m, unroll_n, true); + + // k remainder handling + L_aligned(label_k_remainder_loop_begin); + mov(LoopCount, K); + test(LoopCount, 8); + je(label_k_rem_4, T_NEAR); + + remainder_kernel(unroll_m, unroll_n, 2, 4); + + L_aligned(label_k_rem_4); + mov(LoopCount, K); + test(LoopCount, 4); + je(label_k_rem_2, T_NEAR); + + remainder_kernel(unroll_m, unroll_n, 1, 4); + + L_aligned(label_k_rem_2); + mov(LoopCount, K); + test(LoopCount, 2); + je(label_k_rem_1, T_NEAR); + + Zmm zero = zmm6; + Zmm tmp = zmm5; + + vpxorq(zero, zero, zero); + for (int i = 0; i < um_vecs; i++) { + Zmm a = a_regs[i]; + vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]); + vpunpcklwd(tmp, a, zero); + vpunpckhwd(a, a, zero); + vshufi32x4(a, tmp, a, 0x44); + vshufi32x4(a, a, a, 0xD8); + } + + remainder_kernel(unroll_m, unroll_n, 1, 2); + + L_aligned(label_k_rem_1); + mov(LoopCount, K); + test(LoopCount, 1); + je(label_update_begin, T_NEAR); + + vpxorq(zero, zero, zero); + for (int i = 0; i < um_vecs; i++) { + Zmm a = a_regs[i]; + vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]); + vpunpcklbw(tmp, a, zero); + vpunpckhbw(a, a, zero); + vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1); + vpunpcklwd(tmp, a, zero); + vpunpckhwd(a, a, zero); + vshufi32x4(a, tmp, a, 0x44); + vshufi32x4(a, a, a, 0xD8); + } + + remainder_kernel(unroll_m, unroll_n, 1, 1); + + // Add offsets and update C. + L_aligned(label_update_begin); + + if (enable_offset_r) { + // Add row offsets. + mov(rax, coffset_ry); + for (int j = 0; j < unroll_n; j++) { + Zmm row_offset = zmm0; + + vbroadcastss(row_offset, ptr[rax + size * j]); + + for (int i = 0; i < um_vecs; i++) + vpaddd(c_regs[i][j], c_regs[i][j], row_offset); + } + add(coffset_ry, size * unroll_n); + } + + if (enable_offset_c) { + // Add column offsets. + mov(rax, coffset_cy); + for (int i = 0; i < um_vecs; i++) { + Zmm col_offset = zmm0; + + c_load(col_offset, ptr[rax + size * 16 * i], unroll_m); + + for (int j = 0; j < unroll_n; j++) + vpaddd(c_regs[i][j], c_regs[i][j], col_offset); + } + } + + Reg64 LDC3 = rax; + lea(LDC3, ptr[LDC + LDC * 2]); + + // C updates. + int c_off_j = 0; + for (int j = 0; j < unroll_n; j++) { + if (j > 0 && (j & 3) == 0) { + lea(CO1, ptr[CO1 + LDC * 4]); + c_off_j += 4; + } + + int jj = j - c_off_j; + + for (int i = 0; i < um_vecs; i++) { + Zmm c = c_regs[i][j]; + Zmm c_old = zmm0; + decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj; + + auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i]; + + if (beta_zero) + c_store(c_mem, c, unroll_m); + else { + c_load(c_old, c_mem, unroll_m); + vpaddd(c_old, c, c_old); + c_store(c_mem, c_old, unroll_m); + } + + vpxorq(c, c, c); + } + } + + lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]); +} + +// Outer loop. +void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y, + Label *&cur_outerloop_label) +{ + Label label_m_loop, label_n_loop, label_n_remainder_loops[6]; + + L(*cur_outerloop_label); + cur_outerloop_label++; + if (unroll_x >= IGEMM_UNROLL_M) { + mov(J, M); + cmp(J, unroll_x); + jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label. + } else { + test(J, unroll_x); + jle(*cur_outerloop_label, T_NEAR); + } + + L_aligned(label_m_loop); { + mov(CO1, C); + add(C, unroll_x * size); + + mov(BO, B); + + mov(AA, K); + imul(AA, AA, unroll_x * isize); + lea(AA, ptr[A + AA + isize * prefetch_size_a]); + + if (enable_offset_c) { + mov(rax, coffset_cx); + mov(coffset_cy, rax); + add(rax, unroll_x * size); + mov(coffset_cx, rax); + } + + if (enable_offset_r) { + mov(rax, coffset_rx); + mov(coffset_ry, rax); + } + + mov(I, N); + cmp(I, unroll_y); + jl(label_n_remainder_loops[0], T_NEAR); + + L_aligned(label_n_loop); { + innerloop(unroll_x, unroll_y); + sub(I, unroll_y); + cmp(I, unroll_y); + jge(label_n_loop, T_NEAR); + } + + align(16); + + int label_idx = 0; + for (int uy = 16; uy > 0; uy >>= 1) { + L(label_n_remainder_loops[label_idx++]); + if (unroll_y > uy) { + test(I, uy); + jle(label_n_remainder_loops[label_idx], T_NEAR); + + innerloop(unroll_x, uy); + align(16); + } + } + L(label_n_remainder_loops[label_idx]); + + mov(A, AO); + if (unroll_x >= IGEMM_UNROLL_M) { + sub(J, unroll_x); + cmp(J, unroll_x); + jge(label_m_loop); + } + } + + align(16); +} + +void jit_avx512_core_gemm_s8u8s32_kern::generate() +{ + // Prologue + preamble(); + sub(rsp, stack_alloc_size); + + if (is_windows) { + mov(A, arg_a); + mov(B, arg_b); + } + + mov(C, arg_c); + mov(LDC, arg_ldc); + + sub(A, -offset_a * isize); + sub(B, -offset_b * isize); + + mov(M, qword[M]); + mov(N, qword[N]); + mov(K, qword[K]); + + lea(LDC, ptr[LDC * size]); + + if (enable_offset_c) { + mov(rax, arg_coffset_c); + mov(coffset_cx, rax); + } + if (enable_offset_r) { + mov(rax, arg_coffset_r); + mov(coffset_rx, rax); + } + + for (int i = 0; i < (max_unroll_m >> 4); i++) { + for (int j = 0; j < max_unroll_n; j++) { + auto &c = c_regs[i][j]; + vpxorq(c, c, c); + } + } + + if (!vnni) { + mov(rax, 1); + movq(make_xmm(ones), rax); + vpbroadcastw(ones, make_xmm(ones)); + } + + Label outerloop_labels[8]; + Label *cur_outerloop_label = &outerloop_labels[0]; + + // Main m loop. + outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label); + + // m remainder loops. + for (int um = 32; um > 0; um >>= 1) + if (IGEMM_UNROLL_M > um) + outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label); + + L(*cur_outerloop_label); + + // Epilogue. + add(rsp, stack_alloc_size); + postamble(); +} + + +jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(bool + beta_zero_, bool enable_offset_c_, bool enable_offset_r_) : + jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0), + arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0), + coffset_rx(0), coffset_ry(0) +{ + beta_zero = beta_zero_; + enable_offset_c = enable_offset_c_; + enable_offset_r = enable_offset_r_; + vnni = mayiuse(avx512_core_vnni); + + // Assign integer registers + M = is_windows ? rcx : rdi; + N = is_windows ? rdx : rsi; + K = is_windows ? r8 : rdx; + A = is_windows ? rsi : r8; + B = r9; + C = r10; + LDC = r11; + I = r12; + J = r13; + LoopCount = rax; + AO = r14; + BO = r15; + CO1 = rbx; + CO2 = rbp; + AA = is_windows ? rdi : rcx; + + // Assign vector registers + dp_scratch = zmm6; + ones = zmm7; + for (int i = 0; i < (max_unroll_m >> 4); i++) + a_regs[i] = Zmm(i); + b_regs[0] = zmm4; + b_regs[1] = zmm5; + + int rn = 0; + for (int i = 0; i < (max_unroll_m >> 4); i++) + for (int j = 0; j < max_unroll_n; j++) + c_regs[i][j] = Zmm(8 + rn++); + + // Assign stack variables. + stack_alloc_size = 32; + auto args_offset = stack_alloc_size + get_size_of_abi_save_regs() + + 8 + (is_windows ? 48 : 0); + + arg_a = ptr[rsp + (args_offset - 16)]; + arg_b = ptr[rsp + (args_offset - 8)]; + arg_c = ptr[rsp + (args_offset + 0)]; + arg_ldc = ptr[rsp + (args_offset + 8)]; + arg_coffset_c = ptr[rsp + (args_offset + 16)]; + arg_coffset_r = ptr[rsp + (args_offset + 24)]; + + coffset_cx = qword[rsp + 0]; + coffset_cy = qword[rsp + 8]; + coffset_rx = qword[rsp + 16]; + coffset_ry = qword[rsp + 24]; + + generate(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp new file mode 100644 index 0000000000..e8efcc1cc8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef IGEMM_KERNEL_GENERATOR_HPP +#define IGEMM_KERNEL_GENERATOR_HPP + +#include "jit_generator.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +class jit_avx512_core_gemm_s8u8s32_kern : public jit_generator { +public: + jit_avx512_core_gemm_s8u8s32_kern(bool beta_zero_, bool enable_offset_c_, + bool enable_offset_r_); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_s8u8s32_kern); + +protected: + bool beta_zero; + bool enable_offset_c, enable_offset_r; + bool vnni; + + void prefetch_a(const Xbyak::Address &src) { + prefetcht0(src); + } + void prefetch_b(const Xbyak::Address &src) { + prefetcht0(src); + } + void prefetch_c(const Xbyak::Address &src) { + prefetchw(src); + } + void prefetch_x(const Xbyak::Address &src) { + prefetcht0(src); + } + + void c_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems); + void c_store(const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems); + + void dot_product(const Xbyak::Xmm &dst, const Xbyak::Xmm &src1, + const Xbyak::Xmm &src2); + void kernel_loop(int unroll_m, int unroll_n, bool cfetch); + void remainder_kernel(int unroll_m, int unroll_n, int unroll_k, int bwidth); + void innerloop(int unroll_m, int unroll_n); + void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label); + + void generate(); + + +private: + static const int IGEMM_UNROLL_M = 48; + static const int IGEMM_UNROLL_N = 8; + + static const int isize = 2; + static const int size = 4; + + // Prefetch configuration + static const int prefetch_size_a = 32 * 5; + static const int prefetch_size_b = 32 * 4; + + static const int offset_a = 256, offset_b = 256; + static const int max_unroll_m = 48, max_unroll_n = 8; + + // Integer register assignments + Xbyak::Reg64 M, N, K, A, B, C, LDC, I, J, LoopCount; + Xbyak::Reg64 AO, BO, CO1, CO2, AA; + + // Vector register assignments + Xbyak::Zmm dp_scratch, ones, a_regs[max_unroll_m >> 4], b_regs[2]; + Xbyak::Zmm c_regs[max_unroll_m >> 4][max_unroll_n]; + + // Stack variable assignments + int stack_alloc_size; + Xbyak::Address arg_a, arg_b, arg_c, arg_ldc, arg_coffset_c, arg_coffset_r; + Xbyak::Address coffset_cx, coffset_cy, coffset_rx, coffset_ry; + + void L_aligned(Xbyak::Label &label, int alignment = 16) { + align(alignment); + L(label); + } +}; + +} +} +} + +#endif /* header guard */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp new file mode 100644 index 0000000000..4f0b10dadd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "gemv.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg) { + + blas_t arg_gemv = *arg; + + if ((arg -> offsetc == FIX_OFFSET) && // Fix offset + (arg -> ao == 0) && + (arg -> bo == 0) && + (arg -> co[0] == 0) && + (*(arg -> alpha) == 1.0f) && + ((*(arg -> beta) == 1.0f) || *(arg -> beta) == 0.0f)) { + + if (arg -> n == 1) { + + if (arg -> transa == 1) { // A transpose + arg_gemv.n = arg -> k; + arg_gemv.ldc = 1; + arg_gemv.swap = 0; + if (arg -> transb == 0) { // B non transpose + arg_gemv.ldb = 1; + } + // B transpose arg_gemv.ldb = arg -> ldb + gemv_threading_driver(&arg_gemv); + return 1; + } + } + + if (arg -> m == 1) { + + if (arg -> transb == 0) { // B non transpose + arg_gemv.transa = 1; + arg_gemv.m = arg -> n; + arg_gemv.n = arg -> k; + arg_gemv.a = (int8_t *) arg -> b; + arg_gemv.lda = arg -> ldb; + arg_gemv.b = (uint8_t *) arg -> a; + arg_gemv.swap = 1; + if (arg -> transa == 0) { // A non transpose + arg_gemv.ldb = arg -> lda; + } + else { // A transpose + arg_gemv.ldb = 1; + } + gemv_threading_driver(&arg_gemv); + return 1; + } + } + } + + return 0; +} + + +int gemv_kernel_driver(blas_t *arg) { + + dim_t m = arg -> m; + dim_t n = arg -> n; + uint8_t *a = (uint8_t *) arg -> a; + dim_t lda = arg -> lda; + int8_t *b = (int8_t *) arg -> b; + float beta = *(arg -> beta); + + if (arg -> swap) { + arg -> gemv_u8s8s32_kernel(m, n, 1.0f, a, lda, b, beta, arg -> c); + } + else { + arg -> gemv_s8u8s32_kernel(arg -> m, arg -> n, 1.0f, arg -> a, + arg -> lda, arg -> b, *(arg -> beta), arg -> c); + } + + return 0; +} + +int gemv_threading_driver(blas_t *arg) { + + dim_t nthr_m, nthr_n = 1; + dim_t MB, NB, UM = 16, UN = 64; + dim_t BLOCKM = 192, BLOCKN = 3072; + int status; + dim_t i; + + dim_t nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + uint8_t *new_x = NULL; + int32_t *tmp_y = NULL, *new_y = NULL; + + dim_t m = arg -> m, n = arg -> n; + + blas_t arg_seq = *arg; + float zero = 0.0f; + + nthr_m = std::min(std::max(m / BLOCKM, (dim_t) 1), nthr); + MB = m / nthr_m; + MB = (((MB / UM) * UM) == MB) ? MB : (MB / UM) * UM + UM; + nthr_m = (((m / MB) * MB) == m) ? m / MB : m / MB + 1; + nthr_m = std::min(std::max(nthr_m, (dim_t) 1), nthr); + + while ((nthr_m * (nthr_n + 1) <= nthr) && ((n / (nthr_n + 1)) >= BLOCKN)) { + nthr_n++; + } + + NB = n / nthr_n; + NB = (((NB / UN) * UN) == NB) ? NB : (NB / UN) * UN + UN; + nthr_n = (((n / NB) * NB) == n) ? n / NB : n / NB + 1; + nthr_n = std::min(std::max(nthr_n, (dim_t) 1), nthr / nthr_m); + + nthr = nthr_m * nthr_n; + + if (arg -> ldb != 1) { + new_x = (uint8_t *)malloc(n, 64); + if (new_x == NULL) + return 1; + for (i = 0; i < n; i++) { + new_x[i] = (arg -> b)[i * arg -> ldb]; + } + arg_seq.b = new_x; + arg_seq.ldb = 1; + } + else new_x = (uint8_t *) arg -> b; + + if (arg -> ldc != 1) { + new_y = (int32_t *) malloc(nthr_m * PADD_BYTESIZE_ONPAGE(MB, sizeof(int32_t)), 64); + if (new_y == NULL) { + if (arg -> ldb != 1) { + free(new_x); + } + return 1; + } + } + + // GEMV computation + if (nthr == 1) { + + if (arg -> ldc != 1) { + if (*(arg -> beta) != 0.0f) { + for (i = 0; i < m; i++) { + new_y[i] = arg -> c[i * arg -> ldc]; + } + } + } + + status = gemv_kernel_driver(&arg_seq); + + if (arg -> ldc != 1) { + for (i = 0; i < m; i++) { + arg -> c[i * arg -> ldc] = new_y[i]; + } + } + + if (arg -> ldb != 1) { + free(new_x); + } + if (arg -> ldc != 1) { + free(new_y); + } + return status; + } + + if (nthr_n > 1) { + tmp_y = (int32_t *) malloc((nthr_n - 1) * PADD_BYTESIZE_ONPAGE(m, sizeof(int32_t)), PAGESIZE); + if (tmp_y == NULL) { + if (arg -> ldb != 1) { + free(new_x); + } + return 1; + } + } + + parallel_nd((int) nthr, [&](const dim_t ithr) { + + dim_t m_from, m_to, myM; + dim_t n_from, n_to, myN; + + dim_t n_id, m_id; + dim_t loc_incy = 1; + int32_t *loc_y; + + blas_t arg_loc = arg_seq; + int j; + + m_id = ithr / nthr_n; + n_id = ithr % nthr_n; + + m_from = MB * m_id; + m_to = MB * (m_id + 1); + if ((m_to > m) || (m_id == nthr_m - 1)) + m_to = m; + + myM = m_to - m_from; + + n_from = NB * n_id; + n_to = NB * (n_id + 1); + if ((n_to > n) || (n_id == nthr_n - 1)) + n_to = n; + + myN = n_to - n_from; + + if (n_id != 0) { + arg_loc.beta = &zero; + loc_y = tmp_y + (NEXT_THR_STRIDE(m, sizeof(int32_t))) * (n_id - 1) + m_from; + } + else { + if (arg -> ldc == 1) { + loc_y = arg_seq.c + m_from; + } + else { + // need to copy the block of c in new_y + loc_y = new_y + m_id * NEXT_THR_STRIDE(MB, sizeof(int32_t)); + if (*(arg -> beta) != 0.0f) { + for (j = 0; j < myM; j++) { + loc_y[j] = arg -> c[(m_from + j) * arg -> ldc]; + } + } + } + } + + arg_loc.m = myM; + arg_loc.n = myN; + arg_loc.a = arg_seq.a + m_from * arg_seq.lda + n_from; + arg_loc.b = arg_seq.b + n_from; + arg_loc.c = loc_y; + arg_loc.ldc = loc_incy; + + gemv_kernel_driver(&arg_loc); + + if ((n_id == 0) && (arg -> ldc != 1)) { + for (j = 0; j < myM; j++) { + arg -> c[(m_from + j) * arg -> ldc] = loc_y[j]; + } + } + + }); + + if (nthr_n > 1) { + parallel_nd((int) nthr_m, [&](const dim_t ithr) { + + dim_t j, j_from, j_to, ii; + int32_t acc; + + j_from = MB * ithr; + j_to = MB * (ithr + 1); + if ((j_to > m) || (ithr == nthr - 1)) + j_to = m; + + for (j = j_from; j < j_to; j++) { + acc = 0; + for (ii = 0; ii < nthr_n - 1; ii++) { + acc += tmp_y[ii * NEXT_THR_STRIDE(m, sizeof(int32_t)) + j]; + } + (arg -> c)[j * arg -> ldc] += acc; + } + }); + free(tmp_y); + } + + if (arg -> ldb != 1) { + free(new_x); + } + + if (arg -> ldc != 1) { + free(new_y); + } + + return 0; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp new file mode 100644 index 0000000000..c57a8c1d12 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp @@ -0,0 +1,411 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" + +#ifdef _WIN32 +#define is_windows 1 +#else +#define is_windows 0 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b, + Xbyak::Zmm a, Xbyak::Zmm tmp, + Xbyak::Zmm one, bool swap, + int use_vnni) { + + if (use_vnni) { + if (swap) + vpdpbusd(acc, a, b); + else + vpdpbusd(acc, b, a); + } + + else { + if (swap) + vpmaddubsw(tmp, a, b); + else + vpmaddubsw(tmp, b, a); + vpmaddwd(tmp, tmp, one); + vpaddd(acc, tmp, acc); + } + +} + +void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx, + int b_idx, int nreg_acc, + Xbyak::Reg64 A, Xbyak::Reg64 lda, + Xbyak::Reg64 X, Xbyak::Zmm tmp, + Xbyak::Zmm one, bool swap, int use_vnni, + int use_mask, Xbyak::Opmask mask_n) { + + int i; + int nreg_A = nreg_acc / 2 + (nreg_acc % 2); + + // load X + j + if (use_mask) + vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]); + else + vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]); + + xor_(r14, r14); + // load values of A + for (i = 0; i < nreg_A; i++) { + if (use_mask) + vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); + else + vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); + add(r14, lda); + } + + for (i = 0; i < nreg_A; i++) { + // vnni (acc, b, a, tmp, one, swap, use_vnni) + vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx), + Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); + } + + for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { + if (use_mask) + vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); + else + vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); + add(r14, lda); + } + + for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { + vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx), + Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); + } + +} + +void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A, + Xbyak::Zmm B, Xbyak::Zmm C, + Xbyak::Zmm D) { + + vshufi32x4(dest, A, C, 0x44); + vshufi32x4(A, A, C, 0xEE); + vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3 + + vshufi32x4(dest, B, D, 0x44); + vshufi32x4(B, B, D, 0xEE); + vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3 + + vshufi32x4(A, C, D, 0x88); + vshufi32x4(B, C, D, 0xDD); + vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi + +} + +void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y, + int start_a_idx, int start_acc_idx, + Xbyak::Xmm beta, int use_mask, + Xbyak::Opmask mask_m) { + + int l, i, k, j, last_it; + Xbyak::Label store_label; + + l = 0; + for (k = 0; k < nreg_acc; k += 8) { + for (i = 0, j = k; i < 8; i += 4, j += 2) { + if (j < nreg_acc) { + // shuffle per block of 4 registers + shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest + Xbyak::Zmm(start_acc_idx + j), // A = acc0 + Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1 + Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4 + Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5 + + // extract low and high from dest and hadd + vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0); + vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1); + vphaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l + 1), + Xbyak::Ymm(start_a_idx + l + 2)); + } + l++; + } + + vphaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l - 2), + Xbyak::Ymm(start_a_idx + l - 1)); + + l++; + } + + // eventually add with C and store new value + vxorps(Xbyak::Ymm(start_a_idx), + Xbyak::Ymm(start_a_idx), + Xbyak::Ymm(start_a_idx)); + vucomiss(beta, Xbyak::Ymm(start_a_idx)); + je(store_label, T_NEAR); + + // beta = 1 + for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { + // load Y and add + last_it = (k + 8) > nreg_acc; + if (use_mask && last_it) + vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]); + else + vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]); + + vpaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + k / 8)); + } + + // store + aligned_label(store_label); + for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { + last_it = (k + 8) > nreg_acc; + if (use_mask && last_it) + vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m); + else + vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l)); + } + +} + +template +T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) { + + Xbyak::Opmask mask_n = k1, mask_m = k2; + Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label; + Xbyak::Label n_tail_label, update_c_label, end_label; + constexpr unsigned int n_labels = (1 << unroll_m) - 1; + Xbyak::Label m_tail_label_case[n_labels]; + Xbyak::Label n_loop_label_case[n_labels]; + Xbyak::Label n_tail_label_case[n_labels]; + Xbyak::Label update_c_label_case[n_labels]; + + int i, ii; + + Xbyak::Zmm one, tmp; + Xbyak::Reg64 n = abi_param2, m = abi_param1; + Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3; + Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4; + Xbyak::Reg64 X = is_windows ? rdi : r8; + Xbyak::Xmm beta = xmm1; + Xbyak::Reg64 Y = is_windows ? rsi : r9; + + bool swap = !std::is_same::value; + + // Windows: read on the stack lda, X, beta, Y + + int zmm_idx = 1; + int nreg_acc = 1 << unroll_m; + int nreg_A = 1 << (unroll_m - 1); + int nreg_A_acc = nreg_acc + nreg_A; + + if (!use_vnni) { + // set a zmm register to one + tmp = Xbyak::Zmm(0); + one = Xbyak::Zmm(zmm_idx + 1); + zmm_idx += 2; // one + tmp + } + else { + beta = xmm0; + } + + preamble(); + + if (is_windows) { + mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]); + mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]); + movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]); + mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]); + } + + if (use_vnni && !is_windows) { + movaps(beta, xmm1); + } + + mov(rax, (1 << unroll_n) - 1); + kmovq(k3, rax); + + and_(rax, n); // rax contains n & ((1 << unroll_n) - 1) + mov(rbx, 1); + shlx(rbx, rbx, rax); + sub(rbx, 1); + kmovq(mask_n, rbx); + // mask_n set (AVX512 only), can use rax and rbx again + + // set mask_m for update of the C matrix + // load/store on the C matrix use Ymm so tail according to Ymm size + mov(rax, 7); // 8 * 32 = 256 Ymm size + and_(rax, m); // rax contains m & 7 + mov(rbx, 1); + shlx(rbx, rbx, rax); + sub(rbx, 1); + kmovq(mask_m, rbx); + // mask_m set (AVX512 only), can use rax and rbx again + + // setup register of ones when VNNI instructions not available + if (!use_vnni) { + vmovdqu16(one, ptr[rip + one_label]); + } + + // M loop + // base pointer for A rax contains a + i * lda + // Loop stop when rax >= a + (m & mask_um) * lda = rbx + // loop increment r10 = um * lda + // rbp = Y + i + mov(rax, A); // i = 0 + mov(rbx, m); + and_(rbx, mask_um); + imul(rbx, lda); + add(rbx, A); + mov(r10, lda); + sal(r10, unroll_m); + mov(rbp, Y); + + // N loop + // base pointer for X r11 contains x + j + // Loop stop when r11 >= x + n & mask_un = r12 + // loop increment un + // r13 = rax + j = A + i * lda + j + mov(r12, n); + and_(r12, mask_un); + add(r12, X); + + // M loop + aligned_label(m_loop_label); + cmp(rax, rbx); + jge(m_tail_label, T_NEAR); + + // enter M loop + for(i = 0; i < nreg_acc; i++) { + vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A)); + } + + // N loop + mov(r11, X); // j = 0 + mov(r13, rax); + aligned_label(n_loop_label); + cmp(r11, r12); + jge(n_tail_label, T_NEAR); + + // enter N loop + + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, + r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n); + + // increment rax with un + add(r11, 1 << unroll_n); + add(r13, 1 << unroll_n); + jmp(n_loop_label, T_NEAR); + // end N loop + + // N tail + aligned_label(n_tail_label); + + ktestq(mask_n, k3); + je(update_c_label, T_NEAR); + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, + r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n); + + // update C matrix + aligned_label(update_c_label); + + update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m); + + // increment rax with um * lda + add(rax, r10); + add(rbp, 1 << (unroll_m + 2)); + jmp(m_loop_label, T_NEAR); + // end M loop + + // M tail + aligned_label(m_tail_label); + + // r10 will contain m_tail = m % unroll_m = m & (1 << unroll_m) - 1 + mov(r10, m); + and_(r10, (1 << unroll_m) - 1); + for (ii = 1; ii < 1 << unroll_m; ii++) { + aligned_label(m_tail_label_case[ii-1]); + cmp(r10, ii); + if (ii == (1 << unroll_m) - 1) + jne(end_label, T_NEAR); + else + jne(m_tail_label_case[ii], T_NEAR); + + // m_tail = i, use i accumulators + + for(i = 0; i < ii; i++) { + vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A)); + } + + // N loop + mov(r11, X); // j = 0 + mov(r13, rax); + aligned_label(n_loop_label_case[ii - 1]); + cmp(r11, r12); + jge(n_tail_label_case[ii - 1], T_NEAR); + + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, + lda, r11, tmp, one, swap, use_vnni, 0, mask_n); + + // increment rax with un + add(r11, 1 << unroll_n); + add(r13, 1 << unroll_n); + jmp(n_loop_label_case[ii - 1], T_NEAR); + // end N loop + + // N tail + aligned_label(n_tail_label_case[ii - 1]); + ktestq(mask_n, k3); + je(update_c_label_case[ii - 1], T_NEAR); + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, + lda, r11, tmp, one, swap, use_vnni, 1, mask_n); + + // update C matrix + aligned_label(update_c_label_case[ii - 1]); + update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m); + + if (ii < ((1 << unroll_m) - 1)) + jmp(end_label, T_NEAR); + } + + aligned_label(end_label); + + postamble(); + + if (!use_vnni) { + aligned_label(one_label); + for (i = 0; i < size_vec_reg/8; i++) + dq(0x0001000100010001); + } + + return (T) getCode(); +} + +template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t +jit_avx512_core_gemv_s8u8s32_kern::generate(int); + +template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t +jit_avx512_core_gemv_s8u8s32_kern::generate(int); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp new file mode 100644 index 0000000000..9ea23a5f56 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp @@ -0,0 +1,64 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +class jit_avx512_core_gemv_s8u8s32_kern : jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern); + + // assumes untoll_{m,n} are a power of 2 + static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m + const int mask_um = 0xFFFFFFF0; + static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n + const int mask_un = 0xFFFFFFC0; + const int size_vec_reg = 64; // bytes + + void aligned_label(Xbyak::Label &label, int alignment = 16) { + align(alignment); + L(label); + } + + void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int); + void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64, + Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask); + void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm); + void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask); + +public: + jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {}; + + // m, n, alpha, a, lda, x, beta, y + typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + template + T generate(int use_vnni); + +}; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp new file mode 100644 index 0000000000..544cd2ff25 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp @@ -0,0 +1,819 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_an_kern::jit_avx512_core_u8_copy_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l170; +Xbyak::Label l1f0; +Xbyak::Label l20; +Xbyak::Label l224; +Xbyak::Label l234; +Xbyak::Label l240; +Xbyak::Label l254; +Xbyak::Label l32c; +Xbyak::Label l34; +Xbyak::Label l388; +Xbyak::Label l3b0; +Xbyak::Label l3c0; +Xbyak::Label l3cc; +Xbyak::Label l3dc; +Xbyak::Label l454; +Xbyak::Label l48c; +Xbyak::Label l4a8; +Xbyak::Label l4b8; +Xbyak::Label l4c4; +Xbyak::Label l4d8; +Xbyak::Label l570; +Xbyak::Label l5c4; +Xbyak::Label l5f0; +Xbyak::Label l60c; +Xbyak::Label l61c; +Xbyak::Label l628; +Xbyak::Label l638; +Xbyak::Label l6b0; +Xbyak::Label l6f4; +Xbyak::Label l720; +Xbyak::Label l73c; +Xbyak::Label l74c; +Xbyak::Label l758; +Xbyak::Label l76c; +Xbyak::Label l804; +Xbyak::Label l858; +Xbyak::Label l88c; +Xbyak::Label l8a4; +Xbyak::Label l8b2; +Xbyak::Label l8bc; +Xbyak::Label l8cc; +Xbyak::Label l944; +Xbyak::Label l98c; +Xbyak::Label l9b0; +Xbyak::Label l9c8; +Xbyak::Label l9d8; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x30); + jl(l234, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x30); + mov(I, M); + sar(I, 0x2); + jle(l170, T_NEAR); + align(4); + +L(l34); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B-0x20], xmm4); + movdqu(xword[B-0x10], xmm2); + movdqu(xmm0, xword[A1-0x60]); + movdqu(xmm1, xword[A1+LDA*1-0x60]); + movdqu(xmm2, xword[A1+LDA*2-0x60]); + movdqu(xmm3, xword[A1+LDA3*1-0x60]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B], xmm0); + movdqu(xword[B+0x10], xmm1); + movdqu(xword[B+0x20], xmm4); + movdqu(xword[B+0x30], xmm2); + sub(B, -192); + dec(I); + jg(l34, T_NEAR); + align(4); + +L(l170); + test(M, 0x2); + jle(l1f0, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + movdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + movdqu(xmm4, xword[A1-0x70]); + movdqu(xmm5, xword[A1-0x60]); + add(A1, LDA); + movdqa(xmm6, xmm0); + punpcklbw(xmm0, xmm3); + punpckhbw(xmm6, xmm3); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm6); + movdqa(xmm6, xmm1); + punpcklbw(xmm1, xmm4); + punpckhbw(xmm6, xmm4); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x50], xmm6); + movdqa(xmm6, xmm2); + punpcklbw(xmm2, xmm5); + punpckhbw(xmm6, xmm5); + movdqu(xword[B-0x40], xmm2); + movdqu(xword[B-0x30], xmm6); + sub(B, -96); + align(4); + +L(l1f0); + test(M, 0x1); + jle(l224, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + movdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm2); + sub(B, -48); + align(4); + +L(l224); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + align(4); + +L(l234); + cmp(N, 0x20); + jl(l3c0, T_NEAR); + align(4); + +L(l240); + mov(A1, A); + add(A, 0x20); + mov(I, M); + sar(I, 0x2); + jle(l32c, T_NEAR); + align(4); + +L(l254); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B-0x20], xmm4); + movdqu(xword[B-0x10], xmm2); + sub(B, -128); + dec(I); + jg(l254, T_NEAR); + align(4); + +L(l32c); + test(M, 0x2); + jle(l388, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + movdqu(xmm3, xword[A1-0x70]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm2); + punpckhbw(xmm4, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm4); + movdqa(xmm4, xmm1); + punpcklbw(xmm1, xmm3); + punpckhbw(xmm4, xmm3); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x50], xmm4); + sub(B, -64); + align(4); + +L(l388); + test(M, 0x1); + jle(l3b0, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3b0); + sub(N, 0x20); + cmp(N, 0x20); + jge(l240, T_NEAR); + align(4); + +L(l3c0); + cmp(N, 0x10); + jl(l4b8, T_NEAR); + align(4); + +L(l3cc); + mov(A1, A); + add(A, 0x10); + mov(I, M); + sar(I, 0x2); + jle(l454, T_NEAR); + align(4); + +L(l3dc); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm1, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm1, xmm3); + movdqa(xmm3, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm3, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm1); + punpckhwd(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm3); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + sub(B, -64); + dec(I); + jg(l3dc, T_NEAR); + align(4); + +L(l454); + test(M, 0x2); + jle(l48c, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm2, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + align(4); + +L(l48c); + test(M, 0x1); + jle(l4a8, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l4a8); + sub(N, 0x10); + cmp(N, 0x10); + jge(l3cc, T_NEAR); + align(4); + +L(l4b8); + cmp(N, 0x8); + jl(l61c, T_NEAR); + align(4); + +L(l4c4); + mov(A1, A); + add(A, 0x8); + mov(I, M); + sar(I, 0x3); + jle(l570, T_NEAR); + align(4); + +L(l4d8); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l4d8, T_NEAR); + align(4); + +L(l570); + test(M, 0x4); + jle(l5c4, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l5c4); + test(M, 0x2); + jle(l5f0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l5f0); + test(M, 0x1); + jle(l60c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l60c); + sub(N, 0x8); + cmp(N, 0x8); + jge(l4c4, T_NEAR); + align(4); + +L(l61c); + cmp(N, 0x4); + jl(l74c, T_NEAR); + align(4); + +L(l628); + mov(A1, A); + add(A, 0x4); + mov(I, M); + sar(I, 0x3); + jle(l6b0, T_NEAR); + align(4); + +L(l638); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l638, T_NEAR); + align(4); + +L(l6b0); + test(M, 0x4); + jle(l6f4, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l6f4); + test(M, 0x2); + jle(l720, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l720); + test(M, 0x1); + jle(l73c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l73c); + sub(N, 0x4); + cmp(N, 0x4); + jge(l628, T_NEAR); + align(4); + +L(l74c); + cmp(N, 0x2); + jl(l8b2, T_NEAR); + align(4); + +L(l758); + mov(A1, A); + add(A, 0x2); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l804, T_NEAR); + align(4); + +L(l76c); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l76c, T_NEAR); + align(4); + +L(l804); + test(M, 0x4); + jle(l858, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l858); + test(M, 0x2); + jle(l88c, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l88c); + test(M, 0x1); + jle(l8a4, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l8a4); + sub(N, 0x2); + cmp(N, 0x2); + jge(l758, T_NEAR); + align(4); + +L(l8b2); + cmp(N, 0x1); + jl(l9d8, T_NEAR); + align(4); + +L(l8bc); + mov(A1, A); + add(A, 0x1); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l944, T_NEAR); + align(4); + +L(l8cc); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l8cc, T_NEAR); + align(4); + +L(l944); + test(M, 0x4); + jle(l98c, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l98c); + test(M, 0x2); + jle(l9b0, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l9b0); + test(M, 0x1); + jle(l9c8, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l9c8); + sub(N, 0x1); + cmp(N, 0x1); + jge(l8bc, T_NEAR); + align(4); + +L(l9d8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp new file mode 100644 index 0000000000..1c11fc6cef --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp @@ -0,0 +1,2209 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_at_kern::jit_avx512_core_u8_copy_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1014; +Xbyak::Label l1390; +Xbyak::Label l159c; +Xbyak::Label l173c; +Xbyak::Label l18e4; +Xbyak::Label l1a7c; +Xbyak::Label l1a8c; +Xbyak::Label l1a98; +Xbyak::Label l1ab4; +Xbyak::Label l1c64; +Xbyak::Label l1d74; +Xbyak::Label l1e50; +Xbyak::Label l1f2c; +Xbyak::Label l1ffc; +Xbyak::Label l20; +Xbyak::Label l200c; +Xbyak::Label l2018; +Xbyak::Label l2034; +Xbyak::Label l2110; +Xbyak::Label l21a0; +Xbyak::Label l2210; +Xbyak::Label l2284; +Xbyak::Label l22f0; +Xbyak::Label l2300; +Xbyak::Label l230c; +Xbyak::Label l2324; +Xbyak::Label l2398; +Xbyak::Label l23e8; +Xbyak::Label l242c; +Xbyak::Label l2474; +Xbyak::Label l24b4; +Xbyak::Label l24c4; +Xbyak::Label l24d0; +Xbyak::Label l24e8; +Xbyak::Label l2520; +Xbyak::Label l254c; +Xbyak::Label l2578; +Xbyak::Label l25a8; +Xbyak::Label l25c8; +Xbyak::Label l25d6; +Xbyak::Label l25e0; +Xbyak::Label l25f0; +Xbyak::Label l260c; +Xbyak::Label l262c; +Xbyak::Label l264c; +Xbyak::Label l2668; +Xbyak::Label l2680; +Xbyak::Label l2690; +Xbyak::Label l44; +Xbyak::Label l58c; +Xbyak::Label l8b0; +Xbyak::Label lb14; +Xbyak::Label ld84; +Xbyak::Label lfdc; +Xbyak::Label lfec; +Xbyak::Label lff8; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x30); + jl(lfec, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + lea(I, ptr[I+LDA*8]); + lea(I, ptr[I+LDA*8]); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l58c, T_NEAR); + align(4); + +L(l44); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B+0x40], xmm1); + movdqu(xword[B+0x100], xmm4); + movdqu(xword[B+0x1c0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x50], xmm1); + movdqu(xword[B+0x110], xmm4); + movdqu(xword[B+0x1d0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x60], xmm1); + movdqu(xword[B+0x120], xmm4); + movdqu(xword[B+0x1e0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x70], xmm1); + movdqu(xword[B+0x130], xmm4); + movdqu(xword[B+0x1f0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x80], xmm1); + movdqu(xword[B+0x140], xmm4); + movdqu(xword[B+0x200], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x90], xmm1); + movdqu(xword[B+0x150], xmm4); + movdqu(xword[B+0x210], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0xa0], xmm1); + movdqu(xword[B+0x160], xmm4); + movdqu(xword[B+0x220], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0xb0], xmm1); + movdqu(xword[B+0x170], xmm4); + movdqu(xword[B+0x230], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B], xmm0); + movdqu(xword[B+0xc0], xmm1); + movdqu(xword[B+0x180], xmm4); + movdqu(xword[B+0x240], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x10], xmm0); + movdqu(xword[B+0xd0], xmm1); + movdqu(xword[B+0x190], xmm4); + movdqu(xword[B+0x250], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x20], xmm0); + movdqu(xword[B+0xe0], xmm1); + movdqu(xword[B+0x1a0], xmm4); + movdqu(xword[B+0x260], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x30], xmm0); + movdqu(xword[B+0xf0], xmm1); + movdqu(xword[B+0x1b0], xmm4); + movdqu(xword[B+0x270], xmm3); + sub(A1, -16); + sub(B, -768); + dec(I); + jg(l44, T_NEAR); + align(4); + +L(l58c); + test(M, 0x8); + jle(l8b0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x70], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x80], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x90], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0xa0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0xb0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B], xmm0); + movdqu(xword[B+0xc0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x10], xmm0); + movdqu(xword[B+0xd0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x20], xmm0); + movdqu(xword[B+0xe0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x30], xmm0); + movdqu(xword[B+0xf0], xmm1); + sub(A1, -8); + sub(B, -384); + align(4); + +L(l8b0); + test(M, 0x4); + jle(lb14, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x10], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x10], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x30], xmm0); + sub(A1, -4); + sub(B, -192); + align(4); + +L(lb14); + test(M, 0x2); + jle(ld84, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x50], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x40], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x30], xmm0); + sub(A1, -2); + sub(B, -96); + align(4); + +L(ld84); + test(M, 0x1); + jle(lfdc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x70], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x60], xmm0); + sub(B, -48); + align(4); + +L(lfdc); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + align(4); + +L(lfec); + cmp(N, 0x20); + jl(l1a8c, T_NEAR); + align(4); + +L(lff8); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l1390, T_NEAR); + align(4); + +L(l1014); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B], xmm1); + movdqu(xword[B+0x80], xmm4); + movdqu(xword[B+0x100], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x10], xmm1); + movdqu(xword[B+0x90], xmm4); + movdqu(xword[B+0x110], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x20], xmm1); + movdqu(xword[B+0xa0], xmm4); + movdqu(xword[B+0x120], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x30], xmm1); + movdqu(xword[B+0xb0], xmm4); + movdqu(xword[B+0x130], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x40], xmm1); + movdqu(xword[B+0xc0], xmm4); + movdqu(xword[B+0x140], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x50], xmm1); + movdqu(xword[B+0xd0], xmm4); + movdqu(xword[B+0x150], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0x60], xmm1); + movdqu(xword[B+0xe0], xmm4); + movdqu(xword[B+0x160], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0x70], xmm1); + movdqu(xword[B+0xf0], xmm4); + movdqu(xword[B+0x170], xmm3); + sub(A1, -16); + sub(B, -512); + dec(I); + jg(l1014, T_NEAR); + align(4); + +L(l1390); + test(M, 0x8); + jle(l159c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x10], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0x70], xmm1); + sub(A1, -8); + sub(B, -256); + align(4); + +L(l159c); + test(M, 0x4); + jle(l173c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x10], xmm0); + sub(A1, -4); + sub(B, -128); + align(4); + +L(l173c); + test(M, 0x2); + jle(l18e4, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x50], xmm0); + sub(A1, -2); + sub(B, -64); + align(4); + +L(l18e4); + test(M, 0x1); + jle(l1a7c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l1a7c); + sub(N, 0x20); + cmp(N, 0x20); + jge(lff8, T_NEAR); + align(4); + +L(l1a8c); + cmp(N, 0x10); + jl(l200c, T_NEAR); + align(4); + +L(l1a98); + mov(A1, A); + mov(I, LDA); + shl(I, 0x4); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l1c64, T_NEAR); + align(4); + +L(l1ab4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x40], xmm1); + movdqu(xword[B], xmm4); + movdqu(xword[B+0x40], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B+0x10], xmm4); + movdqu(xword[B+0x50], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x20], xmm1); + movdqu(xword[B+0x20], xmm4); + movdqu(xword[B+0x60], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B-0x10], xmm1); + movdqu(xword[B+0x30], xmm4); + movdqu(xword[B+0x70], xmm3); + sub(A1, -16); + sub(B, -256); + dec(I); + jg(l1ab4, T_NEAR); + align(4); + +L(l1c64); + test(M, 0x8); + jle(l1d74, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B-0x10], xmm1); + sub(A1, -8); + sub(B, -128); + align(4); + +L(l1d74); + test(M, 0x4); + jle(l1e50, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + sub(A1, -4); + sub(B, -64); + align(4); + +L(l1e50); + test(M, 0x2); + jle(l1f2c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x70], xmm0); + sub(A1, -2); + sub(B, -32); + align(4); + +L(l1f2c); + test(M, 0x1); + jle(l1ffc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l1ffc); + sub(N, 0x10); + cmp(N, 0x10); + jge(l1a98, T_NEAR); + align(4); + +L(l200c); + cmp(N, 0x8); + jl(l2300, T_NEAR); + align(4); + +L(l2018); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2110, T_NEAR); + align(4); + +L(l2034); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x40], xmm4); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + movdqu(xword[B-0x30], xmm4); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l2034, T_NEAR); + align(4); + +L(l2110); + test(M, 0x8); + jle(l21a0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l21a0); + test(M, 0x4); + jle(l2210, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l2210); + test(M, 0x2); + jle(l2284, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l2284); + test(M, 0x1); + jle(l22f0, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l22f0); + sub(N, 0x8); + cmp(N, 0x8); + jge(l2018, T_NEAR); + align(4); + +L(l2300); + cmp(N, 0x4); + jl(l24c4, T_NEAR); + align(4); + +L(l230c); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2398, T_NEAR); + align(4); + +L(l2324); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l2324, T_NEAR); + align(4); + +L(l2398); + test(M, 0x8); + jle(l23e8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l23e8); + test(M, 0x4); + jle(l242c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l242c); + test(M, 0x2); + jle(l2474, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2474); + test(M, 0x1); + jle(l24b4, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l24b4); + sub(N, 0x4); + cmp(N, 0x4); + jge(l230c, T_NEAR); + align(4); + +L(l24c4); + cmp(N, 0x2); + jl(l25d6, T_NEAR); + align(4); + +L(l24d0); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2520, T_NEAR); + align(4); + +L(l24e8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l24e8, T_NEAR); + align(4); + +L(l2520); + test(M, 0x8); + jle(l254c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l254c); + test(M, 0x4); + jle(l2578, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2578); + test(M, 0x2); + jle(l25a8, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l25a8); + test(M, 0x1); + jle(l25c8, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l25c8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l24d0, T_NEAR); + align(4); + +L(l25d6); + cmp(N, 0x1); + jl(l2690, T_NEAR); + align(4); + +L(l25e0); + mov(A1, A); + add(A, LDA); + mov(I, M); + sar(I, 0x4); + jle(l260c, T_NEAR); + align(4); + +L(l25f0); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l25f0, T_NEAR); + align(4); + +L(l260c); + test(M, 0x8); + jle(l262c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l262c); + test(M, 0x4); + jle(l264c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l264c); + test(M, 0x2); + jle(l2668, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l2668); + test(M, 0x1); + jle(l2680, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l2680); + sub(N, 0x1); + cmp(N, 0x1); + jge(l25e0, T_NEAR); + align(4); + +L(l2690); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp new file mode 100644 index 0000000000..56c36ee14a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp @@ -0,0 +1,564 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l118; +Xbyak::Label l1a8; +Xbyak::Label l20; +Xbyak::Label l218; +Xbyak::Label l28c; +Xbyak::Label l2f8; +Xbyak::Label l308; +Xbyak::Label l314; +Xbyak::Label l32c; +Xbyak::Label l3a0; +Xbyak::Label l3c; +Xbyak::Label l3f0; +Xbyak::Label l434; +Xbyak::Label l47c; +Xbyak::Label l4bc; +Xbyak::Label l4cc; +Xbyak::Label l4d8; +Xbyak::Label l4f0; +Xbyak::Label l528; +Xbyak::Label l554; +Xbyak::Label l580; +Xbyak::Label l5b0; +Xbyak::Label l5d0; +Xbyak::Label l5de; +Xbyak::Label l5e8; +Xbyak::Label l5f8; +Xbyak::Label l614; +Xbyak::Label l634; +Xbyak::Label l654; +Xbyak::Label l670; +Xbyak::Label l688; +Xbyak::Label l698; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x8); + jl(l308, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l118, T_NEAR); + align(4); + +L(l3c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x40], xmm4); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + movdqu(xword[B-0x30], xmm4); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l3c, T_NEAR); + align(4); + +L(l118); + test(M, 0x8); + jle(l1a8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l1a8); + test(M, 0x4); + jle(l218, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l218); + test(M, 0x2); + jle(l28c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l28c); + test(M, 0x1); + jle(l2f8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2f8); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l308); + cmp(N, 0x4); + jl(l4cc, T_NEAR); + align(4); + +L(l314); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l3a0, T_NEAR); + align(4); + +L(l32c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l32c, T_NEAR); + align(4); + +L(l3a0); + test(M, 0x8); + jle(l3f0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3f0); + test(M, 0x4); + jle(l434, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l434); + test(M, 0x2); + jle(l47c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l47c); + test(M, 0x1); + jle(l4bc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l4bc); + sub(N, 0x4); + cmp(N, 0x4); + jge(l314, T_NEAR); + align(4); + +L(l4cc); + cmp(N, 0x2); + jl(l5de, T_NEAR); + align(4); + +L(l4d8); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l528, T_NEAR); + align(4); + +L(l4f0); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l4f0, T_NEAR); + align(4); + +L(l528); + test(M, 0x8); + jle(l554, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l554); + test(M, 0x4); + jle(l580, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l580); + test(M, 0x2); + jle(l5b0, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l5b0); + test(M, 0x1); + jle(l5d0, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l5d0); + sub(N, 0x2); + cmp(N, 0x2); + jge(l4d8, T_NEAR); + align(4); + +L(l5de); + cmp(N, 0x1); + jl(l698, T_NEAR); + align(4); + +L(l5e8); + mov(A1, A); + add(A, LDA); + mov(I, M); + sar(I, 0x4); + jle(l614, T_NEAR); + align(4); + +L(l5f8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l5f8, T_NEAR); + align(4); + +L(l614); + test(M, 0x8); + jle(l634, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l634); + test(M, 0x4); + jle(l654, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l654); + test(M, 0x2); + jle(l670, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l670); + test(M, 0x1); + jle(l688, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l688); + sub(N, 0x1); + cmp(N, 0x1); + jge(l5e8, T_NEAR); + align(4); + +L(l698); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp new file mode 100644 index 0000000000..53e99d94de --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp @@ -0,0 +1,501 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_bt_kern::jit_avx512_core_u8_copy_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l120; +Xbyak::Label l14c; +Xbyak::Label l168; +Xbyak::Label l178; +Xbyak::Label l184; +Xbyak::Label l194; +Xbyak::Label l20; +Xbyak::Label l20c; +Xbyak::Label l250; +Xbyak::Label l27c; +Xbyak::Label l298; +Xbyak::Label l2a8; +Xbyak::Label l2b4; +Xbyak::Label l2c8; +Xbyak::Label l34; +Xbyak::Label l360; +Xbyak::Label l3b4; +Xbyak::Label l3e8; +Xbyak::Label l400; +Xbyak::Label l40e; +Xbyak::Label l418; +Xbyak::Label l428; +Xbyak::Label l4a0; +Xbyak::Label l4e8; +Xbyak::Label l50c; +Xbyak::Label l524; +Xbyak::Label l534; +Xbyak::Label lcc; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x8); + jl(l178, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x8); + mov(I, M); + sar(I, 0x3); + jle(lcc, T_NEAR); + align(4); + +L(l34); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l34, T_NEAR); + align(4); + +L(lcc); + test(M, 0x4); + jle(l120, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l120); + test(M, 0x2); + jle(l14c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l14c); + test(M, 0x1); + jle(l168, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l168); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l178); + cmp(N, 0x4); + jl(l2a8, T_NEAR); + align(4); + +L(l184); + mov(A1, A); + add(A, 0x4); + mov(I, M); + sar(I, 0x3); + jle(l20c, T_NEAR); + align(4); + +L(l194); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l194, T_NEAR); + align(4); + +L(l20c); + test(M, 0x4); + jle(l250, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l250); + test(M, 0x2); + jle(l27c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l27c); + test(M, 0x1); + jle(l298, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l298); + sub(N, 0x4); + cmp(N, 0x4); + jge(l184, T_NEAR); + align(4); + +L(l2a8); + cmp(N, 0x2); + jl(l40e, T_NEAR); + align(4); + +L(l2b4); + mov(A1, A); + add(A, 0x2); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l360, T_NEAR); + align(4); + +L(l2c8); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l2c8, T_NEAR); + align(4); + +L(l360); + test(M, 0x4); + jle(l3b4, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3b4); + test(M, 0x2); + jle(l3e8, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3e8); + test(M, 0x1); + jle(l400, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l400); + sub(N, 0x2); + cmp(N, 0x2); + jge(l2b4, T_NEAR); + align(4); + +L(l40e); + cmp(N, 0x1); + jl(l534, T_NEAR); + align(4); + +L(l418); + mov(A1, A); + add(A, 0x1); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l4a0, T_NEAR); + align(4); + +L(l428); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l428, T_NEAR); + align(4); + +L(l4a0); + test(M, 0x4); + jle(l4e8, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l4e8); + test(M, 0x2); + jle(l50c, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l50c); + test(M, 0x1); + jle(l524, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l524); + sub(N, 0x1); + cmp(N, 0x1); + jge(l418, T_NEAR); + align(4); + +L(l534); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp new file mode 100644 index 0000000000..49a312fc88 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp @@ -0,0 +1,1283 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_an_kern::jit_avx512_core_u8_copy_sum_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1024; +Xbyak::Label l1090; +Xbyak::Label l10d4; +Xbyak::Label l10fc; +Xbyak::Label l111a; +Xbyak::Label l1124; +Xbyak::Label l113c; +Xbyak::Label l11d4; +Xbyak::Label l1234; +Xbyak::Label l1278; +Xbyak::Label l129c; +Xbyak::Label l12bc; +Xbyak::Label l20; +Xbyak::Label l2a0; +Xbyak::Label l3c0; +Xbyak::Label l438; +Xbyak::Label l480; +Xbyak::Label l48c; +Xbyak::Label l4c8; +Xbyak::Label l5c; +Xbyak::Label l6a8; +Xbyak::Label l7b4; +Xbyak::Label l850; +Xbyak::Label l89c; +Xbyak::Label l8a8; +Xbyak::Label l8d0; +Xbyak::Label l9d0; +Xbyak::Label la64; +Xbyak::Label lab8; +Xbyak::Label lae8; +Xbyak::Label laf4; +Xbyak::Label lb14; +Xbyak::Label lc30; +Xbyak::Label lcc8; +Xbyak::Label ld1c; +Xbyak::Label ld54; +Xbyak::Label ld78; +Xbyak::Label ld84; +Xbyak::Label ld9c; +Xbyak::Label le58; +Xbyak::Label lebc; +Xbyak::Label lef8; +Xbyak::Label lf1c; +Xbyak::Label lf3c; +Xbyak::Label lf48; +Xbyak::Label lf60; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x30); + jl(l480, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x30); + vxorps(ymm8, ymm8, ymm8); + vxorps(ymm9, ymm9, ymm9); + vxorps(ymm10, ymm10, ymm10); + vxorps(ymm11, ymm11, ymm11); + vxorps(ymm12, ymm12, ymm12); + vxorps(ymm13, ymm13, ymm13); + vxorps(ymm14, ymm14, ymm14); + vxorps(ymm15, ymm15, ymm15); + mov(I, M); + sar(I, 0x2); + jle(l2a0, T_NEAR); + align(4); + +L(l5c); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1+LDA*1-0x80]); + vmovdqu(xmm2, xword[A1+LDA*2-0x80]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x80]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovdqu(xword[B-0x80], xmm0); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovdqu(xword[B-0x60], xmm2); + vmovdqu(xword[B-0x50], xmm3); + vmovdqu(xmm0, xword[A1-0x70]); + vmovdqu(xmm1, xword[A1+LDA*1-0x70]); + vmovdqu(xmm2, xword[A1+LDA*2-0x70]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x70]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovdqu(xword[B-0x40], xmm0); + vmovdqu(xword[B-0x30], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovdqu(xword[B-0x20], xmm2); + vmovdqu(xword[B-0x10], xmm3); + vmovdqu(xmm0, xword[A1-0x60]); + vmovdqu(xmm1, xword[A1+LDA*1-0x60]); + vmovdqu(xmm2, xword[A1+LDA*2-0x60]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x60]); + lea(A1, ptr[A1+LDA*4]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovdqu(xword[B], xmm0); + vmovdqu(xword[B+0x10], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vmovdqu(xword[B+0x20], xmm2); + vmovdqu(xword[B+0x30], xmm3); + sub(B, -192); + dec(I); + jg(l5c, T_NEAR); + align(4); + +L(l2a0); + test(M, 0x2); + jle(l3c0, T_NEAR); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1-0x70]); + vmovdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + vmovdqu(xmm6, xword[A1-0x80]); + vmovdqu(xmm4, xword[A1-0x70]); + vmovdqu(xmm5, xword[A1-0x60]); + add(A1, LDA); + vpunpcklbw(xmm3, xmm0, xmm6); + vpunpckhbw(xmm0, xmm0, xmm6); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm8, ymm8, ymm7); + vmovdqu(xword[B-0x80], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x70], xmm0); + vpunpcklbw(xmm3, xmm1, xmm4); + vpunpckhbw(xmm0, xmm1, xmm4); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm10, ymm10, ymm7); + vmovdqu(xword[B-0x60], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x50], xmm0); + vpunpcklbw(xmm3, xmm2, xmm5); + vpunpckhbw(xmm0, xmm2, xmm5); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm12, ymm12, ymm7); + vmovdqu(xword[B-0x40], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x30], xmm0); + sub(B, -96); + align(4); + +L(l3c0); + test(M, 0x1); + jle(l438, T_NEAR); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1-0x70]); + vmovdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm8, ymm8, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x80], xmm0); + vpmovsxbd(ymm7, xmm1); + vpaddd(ymm10, ymm10, ymm7); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbd(ymm7, xmm2); + vpaddd(ymm12, ymm12, ymm7); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x60], xmm2); + sub(B, -48); + align(4); + +L(l438); + mov(A1, qword[ARG_BIAS]); + vmovdqu(yword[A1], ymm8); + vmovdqu(yword[A1+0x20], ymm9); + vmovdqu(yword[A1+0x40], ymm10); + vmovdqu(yword[A1+0x60], ymm11); + vmovdqu(yword[A1+0x80], ymm12); + vmovdqu(yword[A1+0xa0], ymm13); + add(qword[ARG_BIAS], 0xc0); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + vzeroupper(); + align(4); + +L(l480); + cmp(N, 0x20); + jl(l89c, T_NEAR); + align(4); + +L(l48c); + mov(A1, A); + add(A, 0x20); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + pxor(xmm12, xmm12); + pxor(xmm13, xmm13); + pxor(xmm14, xmm14); + pxor(xmm15, xmm15); + mov(I, M); + sar(I, 0x2); + jle(l6a8, T_NEAR); + align(4); + +L(l4c8); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm4); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm2); + sub(B, -128); + dec(I); + jg(l4c8, T_NEAR); + align(4); + +L(l6a8); + test(M, 0x2); + jle(l7b4, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + movdqu(xmm3, xword[A1-0x70]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm2); + punpckhbw(xmm4, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm4); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm4); + movdqa(xmm4, xmm1); + punpcklbw(xmm1, xmm3); + punpckhbw(xmm4, xmm3); + pmovsxbw(xmm5, xmm1); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm13, xmm6); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x50], xmm4); + sub(B, -64); + align(4); + +L(l7b4); + test(M, 0x1); + jle(l850, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + pmovsxbd(xmm5, xmm1); + paddd(xmm12, xmm5); + pshufd(xmm6, xmm1, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm13, xmm6); + pshufd(xmm5, xmm1, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm14, xmm5); + pshufd(xmm6, xmm1, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l850); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + movdqu(xword[A1+0x40], xmm12); + movdqu(xword[A1+0x50], xmm13); + movdqu(xword[A1+0x60], xmm14); + movdqu(xword[A1+0x70], xmm15); + add(qword[ARG_BIAS], 0x80); + sub(N, 0x20); + cmp(N, 0x20); + jge(l48c, T_NEAR); + align(4); + +L(l89c); + cmp(N, 0x10); + jl(lae8, T_NEAR); + align(4); + +L(l8a8); + mov(A1, A); + add(A, 0x10); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + mov(I, M); + sar(I, 0x2); + jle(l9d0, T_NEAR); + align(4); + +L(l8d0); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm1, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm1, xmm3); + movdqa(xmm3, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm3, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm1); + punpckhwd(xmm2, xmm1); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm3); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + sub(B, -64); + dec(I); + jg(l8d0, T_NEAR); + align(4); + +L(l9d0); + test(M, 0x2); + jle(la64, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm2, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm2, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + pmovsxbw(xmm5, xmm2); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + align(4); + +L(la64); + test(M, 0x1); + jle(lab8, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(lab8); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + add(qword[ARG_BIAS], 0x40); + sub(N, 0x10); + cmp(N, 0x10); + jge(l8a8, T_NEAR); + align(4); + +L(lae8); + cmp(N, 0x8); + jl(ld78, T_NEAR); + align(4); + +L(laf4); + mov(A1, A); + add(A, 0x8); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x3); + jle(lc30, T_NEAR); + align(4); + +L(lb14); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(lb14, T_NEAR); + align(4); + +L(lc30); + test(M, 0x4); + jle(lcc8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(lcc8); + test(M, 0x2); + jle(ld1c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(ld1c); + test(M, 0x1); + jle(ld54, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(ld54); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(laf4, T_NEAR); + align(4); + +L(ld78); + cmp(N, 0x4); + jl(lf3c, T_NEAR); + align(4); + +L(ld84); + mov(A1, A); + add(A, 0x4); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x3); + jle(le58, T_NEAR); + align(4); + +L(ld9c); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(ld9c, T_NEAR); + align(4); + +L(le58); + test(M, 0x4); + jle(lebc, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(lebc); + test(M, 0x2); + jle(lef8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(lef8); + test(M, 0x1); + jle(lf1c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(lf1c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(ld84, T_NEAR); + align(4); + +L(lf3c); + cmp(N, 0x2); + jl(l111a, T_NEAR); + align(4); + +L(lf48); + mov(A1, A); + add(A, 0x2); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l1024, T_NEAR); + align(4); + +L(lf60); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(lf60, T_NEAR); + align(4); + +L(l1024); + test(M, 0x4); + jle(l1090, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l1090); + test(M, 0x2); + jle(l10d4, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l10d4); + test(M, 0x1); + jle(l10fc, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l10fc); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(lf48, T_NEAR); + align(4); + +L(l111a); + cmp(N, 0x1); + jl(l12bc, T_NEAR); + align(4); + +L(l1124); + mov(A1, A); + add(A, 0x1); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l11d4, T_NEAR); + align(4); + +L(l113c); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l113c, T_NEAR); + align(4); + +L(l11d4); + test(M, 0x4); + jle(l1234, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l1234); + test(M, 0x2); + jle(l1278, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l1278); + test(M, 0x1); + jle(l129c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l129c); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l1124, T_NEAR); + align(4); + +L(l12bc); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp new file mode 100644 index 0000000000..a4f4ff09c6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp @@ -0,0 +1,3163 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_at_kern::jit_avx512_core_u8_copy_sum_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1750; +Xbyak::Label l1b6c; +Xbyak::Label l1e14; +Xbyak::Label l20; +Xbyak::Label l2068; +Xbyak::Label l226c; +Xbyak::Label l22b8; +Xbyak::Label l22c4; +Xbyak::Label l22f4; +Xbyak::Label l26b4; +Xbyak::Label l28cc; +Xbyak::Label l2a2c; +Xbyak::Label l2b5c; +Xbyak::Label l2c64; +Xbyak::Label l2c94; +Xbyak::Label l2ca0; +Xbyak::Label l2cc8; +Xbyak::Label l2eac; +Xbyak::Label l2fc0; +Xbyak::Label l3078; +Xbyak::Label l3118; +Xbyak::Label l319c; +Xbyak::Label l31c0; +Xbyak::Label l31cc; +Xbyak::Label l31ec; +Xbyak::Label l32e4; +Xbyak::Label l3378; +Xbyak::Label l33dc; +Xbyak::Label l3434; +Xbyak::Label l347c; +Xbyak::Label l349c; +Xbyak::Label l34a8; +Xbyak::Label l34c8; +Xbyak::Label l3558; +Xbyak::Label l35b0; +Xbyak::Label l35f4; +Xbyak::Label l3638; +Xbyak::Label l366c; +Xbyak::Label l368a; +Xbyak::Label l3694; +Xbyak::Label l36a8; +Xbyak::Label l36ec; +Xbyak::Label l3728; +Xbyak::Label l3760; +Xbyak::Label l3794; +Xbyak::Label l37b8; +Xbyak::Label l37d8; +Xbyak::Label l5cc; +Xbyak::Label l6c; +Xbyak::Label l968; +Xbyak::Label lc80; +Xbyak::Label lf1c; +Xbyak::Label lf64; +Xbyak::Label lf70; +Xbyak::Label lfb4; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x30); + jl(lf64, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + lea(I, ptr[I+LDA*8]); + lea(I, ptr[I+LDA*8]); + add(A, I); + vxorps(ymm8, ymm8, ymm8); + vxorps(ymm9, ymm9, ymm9); + vxorps(ymm10, ymm10, ymm10); + vxorps(ymm11, ymm11, ymm11); + vxorps(ymm12, ymm12, ymm12); + vxorps(ymm13, ymm13, ymm13); + vxorps(ymm14, ymm14, ymm14); + vxorps(ymm15, ymm15, ymm15); + mov(I, M); + sar(I, 0x3); + jle(l5cc, T_NEAR); + align(4); + +L(l6c); + vmovq(xmm0, qword[A1-0x80]); + vmovq(xmm1, qword[A1+LDA*1-0x80]); + vmovq(xmm2, qword[A1+LDA*2-0x80]); + vmovq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x80], xmm0); + vmovdqu(xword[B+0x40], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x70], xmm2); + vmovdqu(xword[B+0x50], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x60], xmm0); + vmovdqu(xword[B+0x60], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x50], xmm2); + vmovdqu(xword[B+0x70], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x40], xmm0); + vmovdqu(xword[B+0x80], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x30], xmm2); + vmovdqu(xword[B+0x90], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x20], xmm0); + vmovdqu(xword[B+0xa0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x10], xmm2); + vmovdqu(xword[B+0xb0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B], xmm0); + vmovdqu(xword[B+0xc0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B+0x10], xmm2); + vmovdqu(xword[B+0xd0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x20], xmm0); + vmovdqu(xword[B+0xe0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B+0x30], xmm2); + vmovdqu(xword[B+0xf0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + sub(A1, -8); + sub(B, -384); + dec(I); + jg(l6c, T_NEAR); + align(4); + +L(l5cc); + test(M, 0x4); + jle(l968, T_NEAR); + vmovd(xmm0, dword[A1-0x80]); + vmovd(xmm1, dword[A1+LDA*1-0x80]); + vmovd(xmm2, dword[A1+LDA*2-0x80]); + vmovd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x80], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x60], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x50], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x40], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x30], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x20], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x10], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x10], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B+0x20], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x30], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + sub(A1, -4); + sub(B, -192); + align(4); + +L(l968); + test(M, 0x2); + jle(lc80, T_NEAR); + mov(ax, word[A1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovdqu(xword[B-0x50], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovdqu(xword[B-0x40], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vmovdqu(xword[B-0x30], xmm0); + sub(A1, -2); + sub(B, -96); + align(4); + +L(lc80); + test(M, 0x1); + jle(lf1c, T_NEAR); + mov(al, byte[A1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm8, ymm8, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm10, ymm10, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x70], xmm0); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm12, ymm12, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x60], xmm0); + sub(B, -48); + align(4); + +L(lf1c); + mov(A1, qword[ARG_BIAS]); + vmovdqu(yword[A1], ymm8); + vmovdqu(yword[A1+0x20], ymm9); + vmovdqu(yword[A1+0x40], ymm10); + vmovdqu(yword[A1+0x60], ymm11); + vmovdqu(yword[A1+0x80], ymm12); + vmovdqu(yword[A1+0xa0], ymm13); + add(qword[ARG_BIAS], 0xc0); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + vzeroupper(); + align(4); + +L(lf64); + cmp(N, 0x20); + jl(l22b8, T_NEAR); + align(4); + +L(lf70); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + add(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + pxor(xmm12, xmm12); + pxor(xmm13, xmm13); + pxor(xmm14, xmm14); + pxor(xmm15, xmm15); + mov(I, M); + sar(I, 0x4); + jle(l1750, T_NEAR); + align(4); + +L(lfb4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x80], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x100], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x90], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x110], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0xa0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x120], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0xb0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x130], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x40], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0xc0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x140], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0xd0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x150], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0xe0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x160], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0xf0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x170], xmm3); + sub(A1, -16); + sub(B, -512); + dec(I); + jg(lfb4, T_NEAR); + align(4); + +L(l1750); + test(M, 0x8); + jle(l1b6c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x70], xmm1); + sub(A1, -8); + sub(B, -256); + align(4); + +L(l1b6c); + test(M, 0x4); + jle(l1e14, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + sub(A1, -4); + sub(B, -128); + align(4); + +L(l1e14); + test(M, 0x2); + jle(l2068, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm13, xmm6); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x50], xmm0); + sub(A1, -2); + sub(B, -64); + align(4); + +L(l2068); + test(M, 0x1); + jle(l226c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm12, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm13, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm14, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l226c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + movdqu(xword[A1+0x40], xmm12); + movdqu(xword[A1+0x50], xmm13); + movdqu(xword[A1+0x60], xmm14); + movdqu(xword[A1+0x70], xmm15); + add(qword[ARG_BIAS], 0x80); + sub(N, 0x20); + cmp(N, 0x20); + jge(lf70, T_NEAR); + align(4); + +L(l22b8); + cmp(N, 0x10); + jl(l2c94, T_NEAR); + align(4); + +L(l22c4); + mov(A1, A); + mov(I, LDA); + shl(I, 0x4); + add(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + mov(I, M); + sar(I, 0x4); + jle(l26b4, T_NEAR); + align(4); + +L(l22f4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x40], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x50], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x20], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x60], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x10], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x70], xmm3); + sub(A1, -16); + sub(B, -256); + dec(I); + jg(l22f4, T_NEAR); + align(4); + +L(l26b4); + test(M, 0x8); + jle(l28cc, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x10], xmm1); + sub(A1, -8); + sub(B, -128); + align(4); + +L(l28cc); + test(M, 0x4); + jle(l2a2c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + sub(A1, -4); + sub(B, -64); + align(4); + +L(l2a2c); + test(M, 0x2); + jle(l2b5c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm0); + sub(A1, -2); + sub(B, -32); + align(4); + +L(l2b5c); + test(M, 0x1); + jle(l2c64, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l2c64); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + add(qword[ARG_BIAS], 0x40); + sub(N, 0x10); + cmp(N, 0x10); + jge(l22c4, T_NEAR); + align(4); + +L(l2c94); + cmp(N, 0x8); + jl(l31c0, T_NEAR); + align(4); + +L(l2ca0); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x4); + jle(l2eac, T_NEAR); + align(4); + +L(l2cc8); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l2cc8, T_NEAR); + align(4); + +L(l2eac); + test(M, 0x8); + jle(l2fc0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l2fc0); + test(M, 0x4); + jle(l3078, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l3078); + test(M, 0x2); + jle(l3118, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l3118); + test(M, 0x1); + jle(l319c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l319c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l2ca0, T_NEAR); + align(4); + +L(l31c0); + cmp(N, 0x4); + jl(l349c, T_NEAR); + align(4); + +L(l31cc); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l32e4, T_NEAR); + align(4); + +L(l31ec); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l31ec, T_NEAR); + align(4); + +L(l32e4); + test(M, 0x8); + jle(l3378, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3378); + test(M, 0x4); + jle(l33dc, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l33dc); + test(M, 0x2); + jle(l3434, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3434); + test(M, 0x1); + jle(l347c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l347c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l31cc, T_NEAR); + align(4); + +L(l349c); + cmp(N, 0x2); + jl(l368a, T_NEAR); + align(4); + +L(l34a8); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l3558, T_NEAR); + align(4); + +L(l34c8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pshufd(xmm6, xmm2, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l34c8, T_NEAR); + align(4); + +L(l3558); + test(M, 0x8); + jle(l35b0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l35b0); + test(M, 0x4); + jle(l35f4, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l35f4); + test(M, 0x2); + jle(l3638, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3638); + test(M, 0x1); + jle(l366c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(byte[B-0x7f], al); + sub(B, -2); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + align(4); + +L(l366c); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l34a8, T_NEAR); + align(4); + +L(l368a); + cmp(N, 0x1); + jl(l37d8, T_NEAR); + align(4); + +L(l3694); + mov(A1, A); + add(A, LDA); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l36ec, T_NEAR); + align(4); + +L(l36a8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l36a8, T_NEAR); + align(4); + +L(l36ec); + test(M, 0x8); + jle(l3728, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3728); + test(M, 0x4); + jle(l3760, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3760); + test(M, 0x2); + jle(l3794, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l3794); + test(M, 0x1); + jle(l37b8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l37b8); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l3694, T_NEAR); + align(4); + +L(l37d8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp new file mode 100644 index 0000000000..c7f1393c9d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp @@ -0,0 +1,821 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_bn_kern::jit_avx512_core_u8_copy_sum_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l20; +Xbyak::Label l22c; +Xbyak::Label l340; +Xbyak::Label l3f8; +Xbyak::Label l48; +Xbyak::Label l498; +Xbyak::Label l51c; +Xbyak::Label l540; +Xbyak::Label l54c; +Xbyak::Label l56c; +Xbyak::Label l664; +Xbyak::Label l6f8; +Xbyak::Label l75c; +Xbyak::Label l7b4; +Xbyak::Label l7fc; +Xbyak::Label l81c; +Xbyak::Label l828; +Xbyak::Label l848; +Xbyak::Label l8d8; +Xbyak::Label l930; +Xbyak::Label l974; +Xbyak::Label l9b8; +Xbyak::Label l9ec; +Xbyak::Label la0a; +Xbyak::Label la14; +Xbyak::Label la28; +Xbyak::Label la6c; +Xbyak::Label laa8; +Xbyak::Label lae0; +Xbyak::Label lb14; +Xbyak::Label lb38; +Xbyak::Label lb58; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x8); + jl(l540, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x4); + jle(l22c, T_NEAR); + align(4); + +L(l48); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l48, T_NEAR); + align(4); + +L(l22c); + test(M, 0x8); + jle(l340, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l340); + test(M, 0x4); + jle(l3f8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l3f8); + test(M, 0x2); + jle(l498, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l498); + test(M, 0x1); + jle(l51c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l51c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l540); + cmp(N, 0x4); + jl(l81c, T_NEAR); + align(4); + +L(l54c); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l664, T_NEAR); + align(4); + +L(l56c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l56c, T_NEAR); + align(4); + +L(l664); + test(M, 0x8); + jle(l6f8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l6f8); + test(M, 0x4); + jle(l75c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l75c); + test(M, 0x2); + jle(l7b4, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l7b4); + test(M, 0x1); + jle(l7fc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l7fc); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l54c, T_NEAR); + align(4); + +L(l81c); + cmp(N, 0x2); + jl(la0a, T_NEAR); + align(4); + +L(l828); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l8d8, T_NEAR); + align(4); + +L(l848); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pshufd(xmm6, xmm2, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l848, T_NEAR); + align(4); + +L(l8d8); + test(M, 0x8); + jle(l930, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l930); + test(M, 0x4); + jle(l974, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l974); + test(M, 0x2); + jle(l9b8, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l9b8); + test(M, 0x1); + jle(l9ec, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(byte[B-0x7f], al); + sub(B, -2); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + align(4); + +L(l9ec); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l828, T_NEAR); + align(4); + +L(la0a); + cmp(N, 0x1); + jl(lb58, T_NEAR); + align(4); + +L(la14); + mov(A1, A); + add(A, LDA); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(la6c, T_NEAR); + align(4); + +L(la28); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(la28, T_NEAR); + align(4); + +L(la6c); + test(M, 0x8); + jle(laa8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(laa8); + test(M, 0x4); + jle(lae0, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(lae0); + test(M, 0x2); + jle(lb14, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(lb14); + test(M, 0x1); + jle(lb38, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(lb38); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(la14, T_NEAR); + align(4); + +L(lb58); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp new file mode 100644 index 0000000000..afe4f1713e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp @@ -0,0 +1,647 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l15c; +Xbyak::Label l1f4; +Xbyak::Label l20; +Xbyak::Label l248; +Xbyak::Label l280; +Xbyak::Label l2a4; +Xbyak::Label l2b0; +Xbyak::Label l2c8; +Xbyak::Label l384; +Xbyak::Label l3e8; +Xbyak::Label l40; +Xbyak::Label l424; +Xbyak::Label l448; +Xbyak::Label l468; +Xbyak::Label l474; +Xbyak::Label l48c; +Xbyak::Label l550; +Xbyak::Label l5bc; +Xbyak::Label l600; +Xbyak::Label l628; +Xbyak::Label l646; +Xbyak::Label l650; +Xbyak::Label l668; +Xbyak::Label l700; +Xbyak::Label l760; +Xbyak::Label l7a4; +Xbyak::Label l7c8; +Xbyak::Label l7e8; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x8); + jl(l2a4, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x8); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x3); + jle(l15c, T_NEAR); + align(4); + +L(l40); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l40, T_NEAR); + align(4); + +L(l15c); + test(M, 0x4); + jle(l1f4, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l1f4); + test(M, 0x2); + jle(l248, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l248); + test(M, 0x1); + jle(l280, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l280); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l2a4); + cmp(N, 0x4); + jl(l468, T_NEAR); + align(4); + +L(l2b0); + mov(A1, A); + add(A, 0x4); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x3); + jle(l384, T_NEAR); + align(4); + +L(l2c8); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l2c8, T_NEAR); + align(4); + +L(l384); + test(M, 0x4); + jle(l3e8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l3e8); + test(M, 0x2); + jle(l424, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l424); + test(M, 0x1); + jle(l448, T_NEAR); + movd(xmm0, dword[A1-0x80]); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l448); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l2b0, T_NEAR); + align(4); + +L(l468); + cmp(N, 0x2); + jl(l646, T_NEAR); + align(4); + +L(l474); + mov(A1, A); + add(A, 0x2); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l550, T_NEAR); + align(4); + +L(l48c); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l48c, T_NEAR); + align(4); + +L(l550); + test(M, 0x4); + jle(l5bc, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l5bc); + test(M, 0x2); + jle(l600, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l600); + test(M, 0x1); + jle(l628, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l628); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l474, T_NEAR); + align(4); + +L(l646); + cmp(N, 0x1); + jl(l7e8, T_NEAR); + align(4); + +L(l650); + mov(A1, A); + add(A, 0x1); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l700, T_NEAR); + align(4); + +L(l668); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l668, T_NEAR); + align(4); + +L(l700); + test(M, 0x4); + jle(l760, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l760); + test(M, 0x2); + jle(l7a4, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l7a4); + test(M, 0x1); + jle(l7c8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l7c8); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l650, T_NEAR); + align(4); + +L(l7e8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp new file mode 100644 index 0000000000..4fc11afcbc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp @@ -0,0 +1,116 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "../f32/ref_gemm_f32.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + + int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC; + size_t sizeA = AisN ? lda * k : lda * m; + size_t sizeB = BisN ? ldb * n : ldb * k; + size_t sizeC = ldc * n; + + double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K); + double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K); + double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K); + + if (utils::any_null(dA, dB, dC)) { + free(dA); + free(dB); + free(dC); + return mkldnn_out_of_memory; + } + + auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; }; + auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; }; + + auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; }; + auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; }; + + const int a_rows = AisN ? m : k; + const int a_cols = AisN ? k : m; + mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) { + da_setter(i, j, + static_cast(ia_accessor(i, j)) + static_cast(ao[0])); + }); + + const int b_rows = BisN ? k : n; + const int b_cols = BisN ? n : k; + mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) { + db_setter(i, j, + static_cast(ib_accessor(i, j)) + static_cast(bo[0])); + }); + double one = 1.0, zero = 0.0; + ref_gemm(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero, + dC, LDC, nullptr); + + auto i2d = [=] (int32_t v) { return static_cast(v); }; + auto f2d = [=] (float v) { return static_cast(v); }; + + mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) { + double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]); + double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc])) + + f2d(*alpha) * dC[i + j * ldc] + coffset; + C[i + j * ldc] = math::out_round(math::saturate(val)); + }); + + free(dA); + free(dB); + free(dC); + return mkldnn_success; +} + +template mkldnn_status_t ref_gemm_s8x8s32( + const char *transa, const char *transb, const char *offsetc, + const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const uint8_t *B, const int *LDB, const int8_t *bo, + const float *beta, int32_t *C, const int *LDC, const int32_t *co); + +template mkldnn_status_t ref_gemm_s8x8s32( + const char *transa, const char *transb, const char *offsetc, + const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const int8_t *B, const int *LDB, const int8_t *bo, + const float *beta, int32_t *C, const int *LDC, const int32_t *co); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp new file mode 100644 index 0000000000..6c0370ae99 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_GEMM_S8X8S32_HPP +#define REF_GEMM_S8X8S32_HPP + +#include + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp new file mode 100644 index 0000000000..de1035f3b2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common.hpp" +#include "nstl.hpp" +#include "math_utils.hpp" + +#include "../gemm.hpp" +#include "jit_avx512_core_gemm_s8u8s32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +void compensation_init(const char *offsetC, int32_t *compensation, int len, + const int32_t *oc) { + bool OCisC = (*offsetC == 'C' || *offsetC == 'c'); + bool OCisF = (*offsetC == 'F' || *offsetC == 'f'); + + if (OCisF && (*oc) != 0) { + for (int i = 0; i < len; i++) + compensation[i] = *oc; + } else if (OCisC) { + for (int i = 0; i < len; i++) + compensation[i] = oc[i]; + } else { + parallel_nd(len, [=](int i) { compensation[i] = 0; }); + } +} + +void compensation_compute(bool transa, int m, int k, float alpha, + const int8_t *a, int lda, int32_t *compensation) { + if (!transa) { + const int L2_cache_size = get_cache_size(2, true); + const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1); + const int npanels = k / blocking_factor; + const bool has_tile = k % blocking_factor > 0; + + parallel_nd(npanels, m, [&](int j, int i) { + int32_t val = 0; + for (int jb = 0; jb < blocking_factor; jb++) { + val += a[(i + (ptrdiff_t)j * blocking_factor * lda) + + (ptrdiff_t)jb * lda]; + } + if (alpha != 1.0f) { + val = math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + fetch_and_add(&compensation[i], val); + }); + + if (has_tile) { + parallel_nd(m, [=](int i) { + int32_t val = 0; + for (int j = npanels * blocking_factor; j < k; j++) { + val += a[i + (ptrdiff_t)j * lda]; + } + if (alpha != 1.0f) { + val = math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + fetch_and_add(&compensation[i], val); + }); + } + } else { + parallel_nd(m, [=](int i) { + int32_t val = 0; + for (int j = 0; j < k; j++) { + val += a[j + (ptrdiff_t)i * lda]; + } + if (alpha != 1.0f) { + val = math::out_round(math::saturate( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + compensation[i] += val; + }); + } +} + +void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8, + const int8_t *b_s8, int ldb_s8) { + const int b_cols = transb ? k : n; + + parallel_nd(b_cols, [=](int j) { + const int b_rows = transb ? n : k; + + uint8_t *pb_u8 = b_u8 + j * ldb_u8; + const int8_t *pb_s8 = b_s8 + j * ldb_s8; + + for (int i = 0; i < b_rows; i++) { + (*pb_u8) = (*pb_s8) + 128; + pb_u8++; + pb_s8++; + } + }); +} + +/** + * gemm_s8s8s32 operation is defined as follows: + * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation + * + * where + * - compensation is a vector of length m that contains computed compensation + * that may contain C_offset if applicable. The compensation is applied inside + * gemm_s8u8s32 as a C_offset + * - B_shift is a k-by-n matrix, every element of B_shift is equal to 128 + * + * What is the compensation: + * In order to prepare the matrix B for gemm_s8u8s32 call the B_shift is applied: + * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset = + * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset + * compensation = -alpha * op(A) * B_shift + * Since B_shift is a matrix, every element of which is equal to 128 then + * - if op(A) = A: compensation contains sum of the elements in each row + * scaled by -128 * alpha + * - if op(A) = A**T: compensation contains sum of the elements in each column + * scaled by -128 * alpha + * + * The rest of parameters is described in mkldnn.h + */ +mkldnn_status_t simple_gemm_s8s8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const int8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc) { + if (*oa != 0 || *ob != 0) return mkldnn_unimplemented; + + int M = *m, N = *n, K = *k; + bool transa = (*transA == 'T' || *transA == 't'); + bool transb = (*transB == 'T' || *transB == 't'); + int ld = transb ? N : K; + + uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64); + int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64); + + if (utils::any_null(b_u8, compensation)) { + free(b_u8); + free(compensation); + return mkldnn_out_of_memory; + } + + compensation_init(offsetC, compensation, M, oc); + compensation_compute(transa, M, K, *alpha, a, *lda, compensation); + copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb); + + gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8, + &ld, ob, beta, c, ldc, compensation); + + if ((*offsetC == 'R' || *offsetC == 'r')) + parallel_nd(M, N, + [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; }); + + free(b_u8); + free(compensation); + + return mkldnn_success; +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp new file mode 100644 index 0000000000..03a3d2f7e0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_GEMM_S8S8S32_HPP +#define SIMPLE_GEMM_S8S8S32_HPP + +#include +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t simple_gemm_s8s8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const int8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc); +} +} +} + +#endif // SIMPLE_GEMM_S8S8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp new file mode 100644 index 0000000000..604a728b47 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp @@ -0,0 +1,307 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "gemm_convolution.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "ref_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int M = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * M; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + assert(IMPLICATION( + jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); + assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); + + const int K = jcp.ic * jcp.ks; + const int N = jcp.oc; + + if (jcp.im2col_sz && jcp.id != 1) + parallel_nd(jcp.im2col_sz * jcp.nthr, + [&](ptrdiff_t i) { col[i] = (data_t)0; }); + + const int nb_oh = div_up(jcp.oh, jcp.oh_block); + const int nb_ow = div_up(jcp.ow, jcp.ow_block); + const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow; + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + + int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 }; + size_t start = 0, end = 0; + + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, + nb_oh, owb, nb_ow); + for (size_t iwork = start; iwork < end; ++iwork) { + int oh = ohb * jcp.oh_block; + int ow = owb * jcp.ow_block; + const data_t *_src = src + (n * jcp.ngroups + g) * src_step; + const data_t *_weights = weights + g * weights_g_size; + data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step; + const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); + const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::im2col( + jcp, _src, _col, oh, h_step, ow, w_step); + else + jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od); + } + + const data_t one = 1.0; + + const int m = h_step * w_step; + const int LDA = jcp.im2col_sz ? m : M; + data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow; + + extended_sgemm("N", "N", &m, &N, &K, &one, + jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K, + &this->beta_, _dst, &M); + + data_t *d = _dst; + if (eltwise_) { + // fast branch for ReLU case + if (eltwise_->alg_ == alg_kind::eltwise_relu) { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_; + } + }); + } else { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + d_[oS] = eltwise_->compute_scalar(d_[oS]); + } + }); + } + } else if (jcp.with_bias) { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = bias[g * jcp.oc + oc]; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + } + }); + } + nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh, + owb, nb_ow); + } + }); +} + +void gemm_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int M = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * M; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + const int m = jcp.os; + const int K = jcp.oc; + const int N = jcp.ic * jcp.ks; + const int LDC = jcp.im2col_sz ? m : M; + + const size_t work_amount = (size_t)jcp.ngroups * jcp.mb; + + if (jcp.id > 1) { + const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step); + parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; }); + } + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + + int g{0}, n{0}; + size_t start = 0, end = 0; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb); + for (size_t iwork = start; iwork < end; ++iwork) { + + data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step; + const data_t *_weights = weights + g * weights_g_size; + for (int od = 0; od < jcp.od; ++od) { + const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g) + *dst_step + od * m; + + const data_t zero = 0.0, one = 1.0; + extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M, + _weights, &N, &zero, + jcp.im2col_sz ? _col:_diff_src + od * m, &LDC); + + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::col2im(jcp, _col, + _diff_src); + else + jit_gemm_convolution_utils::col2im_3d(jcp, _col, + _diff_src, od); + } + } + nd_iterator_step(g, jcp.ngroups, n, jcp.mb); + } + }); +} + +void gemm_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto col = scratchpad(ctx).get(key_conv_gemm_col); + auto wei_reduction = scratchpad(ctx).get(key_conv_wei_reduction); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int K = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * K; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + const int k = jcp.os; + const int N = jcp.oc; + const int M = jcp.ic * jcp.ks; + const int LDA = jcp.im2col_sz ? k : K; + + parallel_nd(jcp.im2col_sz * jcp.nthr, + [&](ptrdiff_t i) { col[i] = (data_t)0; }); + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + int ithr_g, nthr_g, ithr_mb, nthr_mb; + size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0}; + + const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1; + jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups, + mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb); + + assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1)); + const int need_reduction = nthr_mb != 1; + + if (ithr_g != -1 && ithr_mb != -1) { + balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end); + balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end); + + assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0)); + + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + data_t *weights_reduce_base = wei_reduction + + ithr_g * nthr_mb * weights_g_size; + data_t *weights_reduce = weights_reduce_base + + ithr_mb * weights_g_size; + + for (size_t g = g_start; g < g_end; ++g) { + data_t *_diff_weights = need_reduction + ? weights_reduce : (diff_weights + g * weights_g_size); + for (size_t mb = mb_start; mb < mb_end; ++mb) { + const data_t *_src = src + (mb*jcp.ngroups+g)*src_step; + for (int od = 0; od < jcp.od; ++od) { + const data_t *_diff_dst = diff_dst + + (mb*jcp.ngroups+g)*dst_step + od * k; + + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::im2col( + jcp, _src, _col, 0, jcp.oh, 0, jcp.ow); + else + jit_gemm_convolution_utils::im2col_3d(jcp, _src, + _col, od); + } + + const data_t zero = 0.0, one = 1.0; + extended_sgemm( + "T", "N", &M, &N, &k, &one, + jcp.im2col_sz ? _col : _src + od * k, + &LDA, _diff_dst, &K, + mb == mb_start && od == 0 ? &zero : &one, + _diff_weights, &M); + } + } + } + if (need_reduction) { + mkldnn_thr_barrier(); + data_t *weights_base = diff_weights + g_start * weights_g_size; + jit_gemm_convolution_utils::bwd_weights_reduction_par( + ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base); + } + } else + if (need_reduction) { mkldnn_thr_barrier(); } + }); + + if (jcp.with_bias) { + parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) { + data_t db = 0; + size_t offset_ = (size_t)g * dst_step + (size_t)oc * K; + for (int mb = 0; mb < jcp.mb; ++mb) + { + size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step; + for (int od = 0; od < jcp.od; ++od) + for (int oh = 0; oh < jcp.oh; ++oh) + PRAGMA_OMP_SIMD(reduction(+:db)) + for (int ow = 0; ow < jcp.ow; ++ow) { + db += diff_dst[offset]; + offset++; + } + } + diff_bias[g*jcp.oc+oc] = db; + }); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp new file mode 100644 index 0000000000..302e46369a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp @@ -0,0 +1,250 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_GEMM_CONVOLUTION_HPP +#define CPU_JIT_GEMM_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "gemm_convolution_utils.hpp" +#include "gemm/gemm.hpp" +#include "ref_eltwise.hpp" + +#include "cpu_convolution_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct gemm_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && post_ops_ok() + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), weights_md(0), dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + + bool post_ops_ok() const { + auto const &po = attr()->post_ops_; + auto is_eltwise = [&](int idx) + { return po.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); }; + + switch (po.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + return false; + } + }; + + gemm_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + , eltwise_(nullptr) + { + const auto &post_ops = pd()->attr()->post_ops_; + const data_t one = 1.0, zero = 0.0; + beta_ = post_ops.find(primitive_kind::sum) >= 0 ? one : zero; + + const int entry_idx = post_ops.find(primitive_kind::eltwise); + if (entry_idx != -1) eltwise_ = new ref_eltwise_scalar_fwd_t( + post_ops.entry_[entry_idx].eltwise); + } + + ~gemm_convolution_fwd_t() { delete eltwise_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + data_t beta_; + + ref_eltwise_scalar_fwd_t* eltwise_; +}; + +struct gemm_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && memory_desc_matches_tag(*diff_src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), diff_src_md(), weights_md(0), diff_dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + }; + + gemm_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct gemm_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*diff_weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), diff_weights_md(0), diff_dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + }; + + gemm_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp new file mode 100644 index 0000000000..f133b1e62b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp @@ -0,0 +1,771 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "cpu_isa_traits.hpp" + +#include "gemm_convolution_utils.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; +using namespace prop_kind; +using namespace data_type; + +namespace jit_gemm_convolution_utils { + +void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, + int od) +{ + const size_t OHW = jcp.oh * jcp.ow; + const size_t im_step = jcp.ih * jcp.iw * jcp.id; + const size_t col_step = jcp.ks * OHW; + + parallel_nd(jcp.ic, [&](int ic) { + const float *__restrict im_loc = im + ic * im_step; + float *__restrict col_loc = col + ic * col_step; + int id = od * jcp.stride_d - jcp.f_pad; + for (int kd = 0; kd < jcp.kd; ++kd) { + float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW; + if (id < 0 || id >= jcp.id) { + int ih_ = -jcp.t_pad; + for (int kh = 0; kh < jcp.kh; ++kh) { + int ih = ih_; + for (int oh = 0; oh < jcp.oh; ++oh) { + if (ih < 0 || ih >= jcp.ih) { + ih += jcp.stride_h; + continue; + } + int iw_ = -jcp.l_pad; + for (int kw = 0; kw < jcp.kw; ++kw) { + int iw = iw_; + for (int ow = 0; ow < jcp.ow; ++ow) { + if (iw < 0 || iw >= jcp.iw) { + iw += jcp.stride_w; + continue; + } + + const size_t col_idx = kw * OHW + oh * jcp.ow + + ow; + + col_[col_idx] = 0; + iw += jcp.stride_w; + } + iw_ += (1 + jcp.dilate_w); + } + ih += jcp.stride_h; + } + ih_ += (1 + jcp.dilate_h); + col_ += jcp.kw * OHW; + } + } else { + const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw; + int ih_ = -jcp.t_pad; + for (int kh = 0; kh < jcp.kh; ++kh) { + int ih = ih_; + for (int oh = 0; oh < jcp.oh; ++oh) { + if (ih < 0 || ih >= jcp.ih) { + ih += jcp.stride_h; + continue; + } + int iw_ = -jcp.l_pad; + for (int kw = 0; kw < jcp.kw; ++kw) { + int iw = iw_; + for (int ow = 0; ow < jcp.ow; ++ow) { + if (iw < 0 || iw >= jcp.iw) { + iw += jcp.stride_w; + continue; + } + + const size_t col_idx = kw * OHW + oh * jcp.ow + + ow; + const size_t im_idx = ih * jcp.iw + iw; + + col_[col_idx] = im_[im_idx]; + iw += jcp.stride_w; + } + iw_ += (1 + jcp.dilate_w); + } + ih += jcp.stride_h; + } + ih_ += (1 + jcp.dilate_h); + col_ += jcp.kw * OHW; + } + } + id += (1 + jcp.dilate_d); + } + }); +} + +/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */ +void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, + float *__restrict col, int hs, int hb, int ws, int wb) { + const size_t im_step = jcp.is; + const size_t col_step = jcp.ks * hb * wb; + if (jcp.stride_w == 1) { + // Generated code is more optimized for stride_w == 1 + // because innermost loop is by width + auto ker = [&](int ic, int kh, int kw, int oh) { + const float *__restrict im_ = im + ic * im_step; + float *__restrict col_ + = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb; + + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) { + for (int ow = 0; ow < wb; ++ow) + col_[ow] = 0.f; + } else { + for (int ow = 0; ow < wb; ++ow) { + const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) + col_[ow] = 0.f; + else { + const size_t im_idx = ih * jcp.iw + iw; + col_[ow] = im_[im_idx]; + } + } + } + }; + + if (jcp.outer_threading) { + for (int ic = 0; ic < jcp.ic; ic++) + for (int kh = 0; kh < jcp.kh; kh++) + for (int kw = 0; kw < jcp.kw; kw++) + for (int oh = 0; oh < hb; oh++) + ker(ic, kh, kw, oh); + } + else { + parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker); + } + } else if (jcp.ic == 1) { + parallel_nd(jcp.kh, hb, [&](int kh, int oh) { + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < wb; ++ow) { + const size_t col_idx + = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; + col[col_idx] = 0; + } + } + else + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < wb; ++ow) { + const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + const size_t col_idx + = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; + const size_t im_idx = ih * jcp.iw + iw; + if (iw < 0 || iw >= jcp.iw) + col[col_idx] = 0; + else + col[col_idx] = im[im_idx]; + } + } + }); + } else { + + parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, + [&](int ic, int kh, int kw, int oh) { + const float *__restrict im_ = im + ic * im_step; + float *__restrict col_ = col + ic * col_step + + ((kh * jcp.kw + kw) * hb + oh) * wb; + + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) { + for (int ow = 0; ow < wb; ++ow) + col_[ow] = 0.f; + } else { + for (int ow = 0; ow < wb; ++ow) { + const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + const size_t im_idx = ih * jcp.iw + iw; + if (iw < 0 || iw >= jcp.iw) + col_[ow] = 0.f; + else + col_[ow] = im_[im_idx]; + } + } + }); + } +} + +inline int limit(int low, int upper, int value) { + return nstl::max(low, nstl::min(upper, value)); +} + +/* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */ +template +void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, + T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws, + int wb) { + uint8_t shift = jcp.signed_input ? 128 : 0; + const int dh = 1 + jcp.dilate_h; + const int dw = 1 + jcp.dilate_w; + const int sh = jcp.stride_h; + const int sw = jcp.stride_w; + const int im_iw_stride = jcp.ic * jcp.ngroups; + const int im_ih_stride = jcp.iw * im_iw_stride; + const int tp = jcp.t_pad; + const int lp = jcp.l_pad; + + if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) { + /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */ + const int hp = hs - tp; + const int wp = ws - lp; + const int ih_start = limit(0, jcp.ih, hp); + const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh); + const int iw_start = limit(0, jcp.iw, wp); + const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw); + + const int ihb = ih_end - ih_start; + const int iwb = iw_end - iw_start; + + const int imtr_ic_stride = ihb * iwb; + const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start; + for (int ic = 0; ic < jcp.ic; ic++) { + const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift; + for (int ih = ih_start; ih < ih_end; ih++) { + const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride; + const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb; + for (int iw = iw_start; iw < iw_end; iw++) + imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride]; + } + } + + const int col_ic_str = hb * wb; + const int col_kw_stride = jcp.ic * col_ic_str; + const int col_kh_stride = jcp.kw * col_kw_stride; + + const int oh_init = ih_start - hp; + const int ow_init = iw_start - wp; + for (int kh = 0; kh < jcp.kh; kh++) { + const ptrdiff_t col_idx_kh = kh * col_kh_stride; + const int oh_kh = oh_init - kh; + const int oh_start = limit(0, hb, oh_kh); + const int oh_end = limit(0, hb, oh_kh + ihb); + for (int kw = 0; kw < jcp.kw; kw++) { + const ptrdiff_t col_idx_kw + = col_idx_kh + kw * jcp.ic * col_ic_str; + const int ow_kw = ow_init - kw; + const int imtr_shift = oh_kh * iwb + ow_kw; + const int ow_start = limit(0, wb, ow_kw); + const int ow_end = limit(0, wb, ow_kw + iwb); + for (int ic = 0; ic < jcp.ic; ic++) { + const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str; + const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift; + for (int oh = 0; oh < oh_start; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + for (int ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + for (int oh = oh_start; oh < oh_end; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb; + for (int ow = 0; ow < ow_start; ++ow) + col[col_idx_oh + ow] = shift; + for (int ow = ow_start; ow < ow_end; ++ow) + col[col_idx_oh + ow] + = imtr[imtr_idx_oh + ow] + shift; + for (int ow = ow_end; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + for (int oh = oh_end; oh < hb; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + for (int ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + } + } + } + } else { + parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb, + [&](int kh, int kw, int ic, int oh) { + const int hp = tp - kh * dh; + const int ih = (oh + hs) * sh - hp; + const ptrdiff_t col_idx_base + = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb; + if (ih < 0 || ih >= jcp.ih) + for (int ow = 0; ow < wb; ow++) + col[col_idx_base + ow] = shift; + else { + const int wp = lp - kw * dw; + const int ow_start = limit(0, wb, div_up(wp, sw) - ws); + const int ow_end + = limit(0, wb, div_up(jcp.iw + wp, sw) - ws); + for (int ow = 0; ow < ow_start; ow++) + col[col_idx_base + ow] = shift; + const int iw_base = ws * sw - wp; + const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; + for (int ow = ow_start; ow < ow_end; ow++) { + const int iw = iw_base + ow * sw; + const ptrdiff_t im_idx + = im_idx_base + iw * im_iw_stride; + col[col_idx_base + ow] = im[im_idx] + shift; + } + for (int ow = ow_end; ow < wb; ow++) + col[col_idx_base + ow] = shift; + } + }); + } +} + +template void im2col_u8(const jit_gemm_conv_conf_t &jcp, + const int8_t *__restrict im, int8_t *__restrict imtr, + uint8_t *__restrict col, int hs, int hb, int ws, int wb); +template void im2col_u8(const jit_gemm_conv_conf_t &jcp, + const uint8_t *__restrict im, uint8_t *__restrict imtr, + uint8_t *__restrict col, int hs, int hb, int ws, int wb); + +/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */ +void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, + int32_t *__restrict im) +{ + parallel(0, [&](const int ithr, const int nthr) { + int h_nthr = nstl::min(jcp.ih, nthr); + int w_nthr = nstl::min(jcp.iw, nthr / h_nthr); + int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0; + if (ithr < h_nthr * w_nthr) { + h_ithr = ithr / w_nthr; + w_ithr = ithr % w_nthr; + balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e); + balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e); + } else { + h_ithr = w_ithr = -ithr; + h_s = h_e = w_s = w_e = -1; + } + + for (int ih = h_s; ih < h_e; ++ih) { + for (int iw = w_s; iw < w_e; ++iw) { + PRAGMA_OMP_SIMD() + for (int ic = 0; ic < jcp.ic; ++ic) { + im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0; + } + } + } + + // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh] + for (int oh = 0; oh < jcp.oh; ++oh) { + for (int ow = 0; ow < jcp.ow; ++ow) { + for (int kh = 0; kh < jcp.kh; ++kh) { + const int ih = oh * jcp.stride_h + - jcp.t_pad + kh * (1 + jcp.dilate_h); + if (ih < h_s || ih >= h_e) continue; + + for (int kw = 0; kw < jcp.kw; ++kw) { + const int iw = ow * jcp.stride_w + - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < w_s || iw >= w_e) continue; + + const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh + + kh) * jcp.kw + kw) * jcp.ic; + const size_t im_idx + = (ih * jcp.iw + iw) * jcp.ic; + PRAGMA_OMP_SIMD() + for (int ic = 0; ic < jcp.ic; ++ic) { + im[im_idx + ic] += col[col_idx + ic]; + } + } + } + } + } + }); +} + +void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, + int od) +{ + parallel_nd(jcp.ic, [&](int ic) { + const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os; + float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id; + + int id = od * jcp.stride_d - jcp.f_pad; + for (int kd = 0; kd < jcp.kd; ++kd) { + if (id < 0 || id >= jcp.id) { + col_ += jcp.kh * jcp.kw * jcp.os; + id += (1 + jcp.dilate_d); + continue; + } + + float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw; + + for (int oh = 0; oh < jcp.oh; ++oh) { + for (int kh = 0; kh < jcp.kh; ++kh) { + const int ih = oh * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) continue; + + for (int ow = 0; ow < jcp.ow; ++ow) { + for (int kw = 0; kw < jcp.kw; ++kw) { + const int iw = ow * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) continue; + + const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; + const size_t im_idx = ih*jcp.iw + iw; + im_[im_idx] += col_[col_idx]; + }} + }} + + col_ += jcp.kh * jcp.kw * jcp.os; + id += (1 + jcp.dilate_d); + } + }); +} + +void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) { + const size_t col_step = jcp.ks * jcp.os; + const size_t im_step = jcp.ih * jcp.iw; + const int iS = jcp.ih * jcp.iw; + + parallel_nd(jcp.ic, [&](int ic) { + float *__restrict im_ = im + ic * im_step; + const float *__restrict col_ = col + ic * col_step; + PRAGMA_OMP_SIMD() + for (int is = 0; is < iS; ++is) im_[is] = 0.; + + for (int kh = 0; kh < jcp.kh; ++kh) { + for (int oh = 0; oh < jcp.oh; ++oh) { + const int ih = + oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) continue; + + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < jcp.ow; ++ow) { + const int iw = + ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) continue; + + const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; + const size_t im_idx = ih*jcp.iw + iw; + im_[im_idx] += col_[col_idx]; + } + } + } + } + }); +} + +status_t init_conf(jit_gemm_conv_conf_t &jcp, + memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, int max_threads) { + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + const int is_1d = ndims == 3; + const int is_3d = ndims == 5; + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.id = is_3d ? src_d.dims()[2] : 1; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.od = is_3d ? dst_d.dims()[2] : 1; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = is_3d ? cd.padding[0][0] : 0; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_d = is_3d ? cd.strides[0] : 1; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_d = is_3d ? cd.dilates[0] : 0; + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef + || cd.diff_bias_desc.format_kind != format_kind::undef; + + jcp.is = jcp.ih * jcp.iw; + jcp.os = jcp.oh * jcp.ow; + jcp.ks = jcp.kh * jcp.kw * jcp.kd; + + jcp.signed_input = src_d.data_type() == data_type::s8; + + jcp.im2col_sz = !everyone_is(true, + jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id, + jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1, + jcp.ks == 1, !jcp.signed_input) + ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0; + + jcp.outer_threading = false; + + bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8) + && weights_d.data_type() == s8; + + const int vlen = mayiuse(avx512_common) + ? cpu_isa_traits::vlen + : mayiuse(avx) + ? cpu_isa_traits::vlen + : mayiuse(sse42) ? cpu_isa_traits::vlen : 4; + const int simd_w = vlen / (is_int8_conv ? 1 : 4); + + const bool is_bwd_d = jcp.prop_kind == backward_data; + const bool is_bwd_w = jcp.prop_kind == backward_weights; + const bool is_fwd = !is_bwd_d && !is_bwd_w; + jcp.oh_block = is_fwd ? jcp.oh : jcp.ih; + jcp.ow_block = is_fwd ? jcp.ow : jcp.iw; + + using namespace memory_tracking::names; + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + + // TODO: maybe mitigate blocking restriction + const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + const int L2 = get_cache_size(2, true) + / (is_int8_conv ? sizeof(int8_t) : sizeof(float)); + bool is_blocking_applicable = true + && is_fwd && jcp.im2col_sz + && jcp.id == 1 && jcp.od == 1 + && jcp.dilate_h == 0 && jcp.dilate_w == 0 + && !is_depthwise + && wei_size < L2/2; + if (is_blocking_applicable) { + // looking for oh and ow blocking + int h_block{ jcp.oh_block }, w_block{ jcp.ow_block }; + const int ic = jcp.ic; + const int oc = jcp.oc; + const int iw = jcp.iw; + const int ow = jcp.ow; + const int oh = jcp.oh; + const int os = oh * ow; + + // 1. cache requirement + int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow); + if (is_int8_conv) { + // Heuristic rule: gemm needed a lot of memory for internal usage + row_size *= 5; + // memory for accumulators + row_size += oc * ow * sizeof(uint32_t); + // memory for transposition + row_size += ic * iw; + } + + h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size))); + if (h_block == 1) { + int col_size = ic * jcp.ks + 2 * (ic + oc); + if (is_int8_conv) { + col_size *= 5; + col_size += oc * sizeof(uint32_t); + col_size += ic; + } + w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size))); + } + + // 2. threading requirement + if (h_block != oh) + h_block = nstl::max(1, rnd_dn(h_block, 4)); + if (w_block != ow) + w_block = nstl::max(1, rnd_dn(w_block, simd_w)); + + float thr_eff = 0.f; + float thr_eff_treshold = 0.9f; + if (w_block == ow) { + do { + int nb_h = div_up(oh, h_block); + size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; + float disb = (float)oh / rnd_up(oh, h_block); + thr_eff = (float)work / rnd_up(work, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff >= thr_eff_treshold) + break; + h_block = rnd_dn(h_block - 4, 4); + } while (h_block > 0); + } + if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block + { + h_block = 1; + int nb_h = oh; + do { + int nb_w = div_up(ow, w_block); + size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; + float disb = (float)ow / rnd_up(ow, w_block); + thr_eff = (float)work_amount / rnd_up(work_amount, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff > thr_eff_treshold) + break; + w_block = rnd_dn(w_block - simd_w, simd_w); + } while (w_block > 0); + } + h_block = nstl::max(1, h_block); + w_block = nstl::max(1, w_block); + const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) { + jcp.oh_block = h_block; + jcp.ow_block = w_block; + jcp.outer_threading = true; + } + // updating jcp.im2col_sz + if (jcp.oh_block != 1) + jcp.ow_block = ow; + jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; + } + // For threading selection in bwd_d we do: + // 1. Rough estimation of efficiency for inner and outer threading. + // 2. Gemm size estimation in assumption that it does not work + // so effectively for small sizes. + // 64K - this is heuristic gemm size per thread threshold. + const int gemm_thrld = 64 * 1024; + + if (is_int8_conv) { + if (is_fwd) { + if (!jcp.outer_threading) { + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + const size_t outer_work = jcp.ngroups * jcp.mb; + const float outer_thr_eff + = (float)outer_work / rnd_up(outer_work, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(int8_t) * jcp.nthr * jcp.im2col_sz); + scratchpad.book(key_conv_int_dat_in_acc_dt, + sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc); + scratchpad.book(key_conv_gemm_imtr, + sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic); + } else if (is_bwd_d) { + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + const size_t outer_work = jcp.ngroups * jcp.mb; + const float outer_thr_eff + = (float)outer_work / rnd_up(outer_work, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(int32_t) * jcp.nthr * jcp.im2col_sz); + scratchpad.book(key_conv_int_dat_in_acc_dt, + sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic); + } else if (is_bwd_w) { + assert(!"unimplemented prop_kind"); + return status::unimplemented; + } + } else { + if (is_fwd) { + if (!jcp.outer_threading) { + const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od; + const float outer_thr_eff = (float)outer_work_amount + / rnd_up(outer_work_amount, max_threads); + const size_t inner_work_amount + = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w); + const float inner_thr_eff = (float)inner_work_amount + / rnd_up(inner_work_amount, max_threads); + jcp.outer_threading = jcp.os / max_threads < 512 + && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } + } else if (is_bwd_d) { + const size_t outer_work_amount = jcp.ngroups * jcp.mb; + const float outer_thr_eff = (float)outer_work_amount + / rnd_up(outer_work_amount, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff = (float)inner_work + / rnd_up(inner_work, max_threads); + jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64) + && (jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } else if (is_bwd_w) + jcp.outer_threading = jcp.os / max_threads < 256 + && (jcp.mb != 1 || jcp.ngroups > 2); + + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(float) * jcp.nthr * jcp.im2col_sz); + + if (is_bwd_w) { + jcp.need_wei_reduction = mkldnn_thr_syncable() + ? jcp.mb != 1 && jcp.nthr != 1 : false; + scratchpad.book(key_conv_wei_reduction, + sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size()); + } + } + + return status::success; +} + +void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g, + int &nthr_g, int &ithr_mb, int &nthr_mb) { + nthr_g = nstl::min(ngroups, nthr); + nthr_mb = nstl::min(mb, nthr / nthr_g); + if (ithr / nthr_mb >= ngroups) { + ithr_g = ithr_mb = -1; + } else { + ithr_g = ithr / nthr_mb; + ithr_mb = ithr % nthr_mb; + } +} + +void bwd_weights_reduction_par(int ithr, int nthr, + const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, + float *weights) { + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + size_t weights_start{0}, weights_end{0}; + balance211(weights_g_size, nthr, ithr, weights_start, weights_end); + + for (int i = 0; i < nthr; ++i) { + const float *ws_i = weights_reduce_ws + i * weights_g_size; + for (size_t s = weights_start; s < weights_end; ++s) + weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s]; + } +} + +}; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp new file mode 100644 index 0000000000..e006789344 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp @@ -0,0 +1,66 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP +#define CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_engine.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace jit_gemm_convolution_utils { + +void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, + int od); +void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, + float *__restrict col, int hs, int hb, int ws, int wb); +template +void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, + T* __restrict imtr, uint8_t *__restrict col, + int hs, int hb, int ws, int wb); + +void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, + int32_t *__restrict im); +void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, + int od); +void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im); + +status_t init_conf(jit_gemm_conv_conf_t &jcp, + memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, int max_threads); + +void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, + int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb); +void bwd_weights_reduction_par(int ithr, int nthr, + const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, + float *weights); + +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp new file mode 100644 index 0000000000..2872122f0d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp @@ -0,0 +1,156 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" + +#include "gemm_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::primitive_kind; + +template +void gemm_inner_product_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = !memory_desc_matches_one_of_tag( + *pd()->weights_md(), hwio, dhwio, io); + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + + float alpha = 1.0, beta = 0.0; + extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights, + wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias); + + if (do_relu) { + float nslope = post_ops.entry_[0].eltwise.alpha; + parallel_nd(MB, OC, [&](int mb, int oc) { + size_t dst_off = mb * OC + oc; + if (dst[dst_off] < 0) + dst[dst_off] *= nslope; + }); + } +} + +template +void gemm_inner_product_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->weights_md(), hwio, dhwio, io); + + float alpha = 1.0, beta = 0.0; + extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights, + wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC); +} + +template +void gemm_inner_product_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + diff_dst += diff_dst_d.offset0(); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->diff_weights_md(), hwio, dhwio, io); + + float alpha = 1.0, beta = 0.0; + if (wei_tr) + extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC, + &beta, diff_weights, &OC); + else + extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC, + &beta, diff_weights, &IC); + + if (diff_bias) { + diff_bias += diff_bias_d.offset0(); + constexpr int blksize = 8; + const int OC_blocks = OC / blksize; + const int rem_OC = OC % blksize; + parallel(0, [&](const int ithr, const int nthr) { + int oc_st{0}, oc_e{0}; + balance211(OC_blocks, nthr, ithr, oc_st, oc_e); + oc_st = oc_st * blksize; + oc_e = oc_e * blksize; + + PRAGMA_OMP_SIMD() + for (int oc = oc_st; oc < oc_e; ++oc) { + diff_bias[oc] = diff_dst[oc]; + } + + for (int mb = 1; mb < MB; ++mb) { + PRAGMA_OMP_SIMD() + for (int oc = oc_st; oc < oc_e; ++oc) { + diff_bias[oc] += diff_dst[mb * OC + oc]; + } + } + + if (rem_OC != 0 && ithr == nthr-1) { + for (int oc = OC_blocks * blksize; oc < OC; oc++) + diff_bias[oc] = diff_dst[oc]; + for (int mb = 1; mb < MB; ++mb) { + for (int oc = OC_blocks * blksize; oc < OC; oc++) { + diff_bias[oc] += diff_dst[mb * OC + oc]; + } + } + } + }); + } +} + +template struct gemm_inner_product_fwd_t; +template struct gemm_inner_product_bwd_data_t; +template struct gemm_inner_product_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp new file mode 100644 index 0000000000..acf0a49b9a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp @@ -0,0 +1,157 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_GEMM_INNER_PRODUCT_HPP +#define CPU_GEMM_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "gemm/gemm.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct gemm_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type, + src_md()->data_type, + weights_md()->data_type, + dst_md()->data_type, + with_bias() ? weights_md(1)->data_type : data_type) + && attr()->output_scales_.has_default_values() + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_ == 1, + attr()->post_ops_.entry_[0].is_relu(true, false)) + && dense_gemm_consitency_check(src_md(), weights_md(), + dst_md()); + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct gemm_inner_product_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_data_pd_t { + using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_data + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + diff_src_md()->data_type, + weights_md()->data_type, + diff_dst_md()->data_type) + && attr()->has_default_values() + && dense_gemm_consitency_check(diff_src_md(), weights_md(), + diff_dst_md()); + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct gemm_inner_product_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_weights_pd_t { + using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_weights + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + src_md()->data_type, + diff_weights_md()->data_type, + diff_dst_md()->data_type, + with_bias() ? diff_weights_md(1)->data_type : data_type) + && attr()->has_default_values() + && dense_gemm_consitency_check(src_md(), diff_weights_md(), + diff_dst_md()); + + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp new file mode 100644 index 0000000000..fed7e4d693 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp @@ -0,0 +1,740 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "math_utils.hpp" + +#include "simple_q10n.hpp" + +#include "gemm/gemm.hpp" +#include "gemm_x8s8s32x_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace mkldnn::impl::memory_tracking::names; + +template +void _gemm_x8s8s32x_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src_base = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst_base = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + assert(IMPLICATION( + jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); + assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base, + scratchpad); + }); +} + +template +_gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::pp_ker_t( + const pd_t *pd) + : ker_(nullptr) + , jcp_(pd->jcp_) + , OC_(pd->jcp_.oc) + , OS_(pd->jcp_.os) + , bias_data_type_(data_type::undef) + , bias_data_type_size_(0) + , scale_idx_mult_(0) + , do_bias_(false) + , do_relu_(false) + , do_sum_(false) +{ + using namespace types; + + const auto dst_md = memory_desc_wrapper(pd->dst_md()); + dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1); + + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + + auto &post_ops = pd->attr()->post_ops_; + + int entry_idx = -1; + for (int idx = 0; idx < post_ops.len_; ++idx) { + const auto &e = post_ops.entry_[idx]; + if (e.is_relu(true, false)) { + entry_idx = idx; + break; + } + } + do_relu_ = entry_idx >= 0; + + do_signed_scaling_ = jcp_.signed_input; + + do_sum_ = post_ops.contain(primitive_kind::sum, 0); + do_bias_ = pd->with_bias(); + bias_data_type_ = pd->desc()->bias_desc.data_type; + if (do_bias_) { + assert(bias_data_type_ != data_type::undef); + bias_data_type_size_ = data_type_size(bias_data_type_); + } + const size_t vlen_start + = cpu_isa_traits::vlen / sizeof(float); + + for (size_t i = vlen_start; i > 0; i--) { + if (OC_ % i == 0) { + vlen_ = i; + break; + } + } + + if (!mayiuse(avx512_core)) + // use fallback code for older CPUs + return; + else + generate(); +} + +template +void _gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::generate() +{ + using namespace Xbyak; + using namespace utils; + + // TODO: clean-up + Reg64 reg_param = abi_param1; + Reg64 reg_dst = rdx; + Reg64 reg_acc = rax; + Reg64 reg_bias = rbx; + Reg64 reg_scales = rsi; + + Reg64 reg_len = r8; + Reg64 reg_tmp = rcx; // intentional for shifting purposes + Reg64 reg_oc_offset = r9; + Reg64 reg_rem_mask_short = r10; + Reg64 reg_rem_mask_vlen = r11; + Opmask kreg_rem_mask_short = k1; + Opmask kreg_rem_mask_vlen = k3; + Opmask kreg_relu_cmp = k2; + + const size_t vlen = vlen_; + + Zmm vreg_zero = Zmm(0); + Zmm vreg_scale = Zmm(1); + Zmm vreg_nslope = Zmm(2); + Zmm vreg_sum_scale = Zmm(3); + Zmm vreg_signed_scale = Zmm(4); + + size_t def_unroll = 4; + size_t max_unroll = 12; + size_t zmm_step = 2; + if (do_sum_) { + max_unroll = 8; + zmm_step = 3; + } + + auto vreg_dst = [&](int idx) { + return Zmm(5 + idx * zmm_step + 0); + }; + auto vreg_bias = [&](int idx) { + return Zmm(5 + idx * zmm_step + 1); + }; + auto vreg_prev_dst = [&](int idx) { + return Zmm(5 + idx * zmm_step + 2); + }; + + preamble(); + +#define PARAM_OFF(x) offsetof(ker_args, x) + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); + vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]); + vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]); + if (scale_idx_mult_ == 0) + vbroadcastss(vreg_scale, dword[reg_scales]); + +#undef PARAM_OFF + + mov(reg_rem_mask_vlen, 1); + shl(reg_rem_mask_vlen, vlen); + sub(reg_rem_mask_vlen, 1); + kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen); + + if (do_relu_ || dst_type == data_type::u8) + vxorps(vreg_zero, vreg_zero, vreg_zero); + + // Load accumulated value, convert to float, apply sum (if any), + // bias (if any), scaling, and relu (if any); + // then convert to destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask_short; + else + vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen; + vmovups(vreg_scale_, scale_addr); + } + + auto vreg_dst_ = vreg_dst(idx); + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask_short; + else + vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen; + vcvtdq2ps(vreg_dst_, acc_addr); + + if (do_signed_scaling_) + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask_short; + else + vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen; + + switch (bias_data_type_) { + case data_type::s8: + vpmovsxbd(vreg_bias_, bias_addr); + break; + case data_type::u8: + vpmovzxbd(vreg_bias_, bias_addr); + break; + case data_type::s32: + case data_type::f32: + vmovups(vreg_bias_, bias_addr); + break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); + } + + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); + + auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; + + if (do_sum_) + { + auto vreg_prev_dst_ = vreg_prev_dst(idx); + if (apply_mask) + vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short; + else + vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen; + + switch (dst_type) { + case data_type::f32: + case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break; + case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break; + case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break; + default: assert(!"unsupported data type"); + } + if (dst_type != data_type::f32) + vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx)); + + vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale); + } + + if (do_relu_) { + vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); + vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); + } + + if (dst_type != data_type::f32) { + vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } + + if (dst_type == data_type::u8) + vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + switch (dst_type) { + case data_type::s8: + vpmovsdb(dst_addr, vreg_dst_); + break; + case data_type::u8: + vpmovusdb(dst_addr, vreg_dst_); + break; + case data_type::f32: + case data_type::s32: + vmovups(dst_addr, vreg_dst_); + break; + default: assert(!"unimplemented"); + } + }; + + // Advance all pointers by an immediate + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * sizeof(dst_data_t)); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); + } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); + }; + + // Advance all pointers by a value stored in a register + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); + } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); + }; + + // Rewind pointers that point to data that is indexed by output channel + // (bias or per-oc scaling factors) + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); + } + add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t)); + }; + + // <--------- OC ---------------> + // + // ^ ................+..............+-------------+....................... + // | . : not accessed |Prologue loop| . + // | . +--------------+-------------+ . + // . | | . + // O . | Main loop (unrolled) | . + // S . | | . + // . +--------------+-------------+ . + // | . | Epilogue loop|not accessed : . + // v ................+--------------+.............+....................... + + Label prologue_end; + cmp(reg_oc_offset, 0); + je(prologue_end, T_NEAR); + + // Prologue loop + { + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); + + Label prologue_loop, prologue_loop_tail, prologue_loop_end; + cmp(reg_tmp, vlen); + jle(prologue_loop_tail, T_NEAR); + L(prologue_loop); { + compute(0, 0, false); + advance_ptrs_imm(vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); + jge(prologue_loop, T_NEAR); + } + + L(prologue_loop_tail); + mov(reg_rem_mask_short, 1); + // cl == reg_tmp because reg_tmp <= vlen here + shl(reg_rem_mask_short, cl); + sub(reg_rem_mask_short, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); + + L(prologue_loop_end); + rewind_ptrs(); + } + L(prologue_end); + + // Main loop + Label main_loop_end; + { + cmp(reg_len, OC_); + jle(main_loop_end, T_NEAR); + + Label main_loop; + L(main_loop); { + size_t OC_loop, OC_tail; + if (OC_ < max_unroll * vlen) { + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } + else { + OC_loop = vlen * def_unroll; + OC_tail = OC_ % OC_loop; + } + + assert(!!OC_loop || !!OC_tail); + + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + unsigned tail_mask = (1 << vlen_tail) - 1; + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask_short, reg_tmp); + } + + if (OC_loop) { + mov(reg_tmp, rnd_dn(OC_, OC_loop)); + Label oc_loop; + L(oc_loop); { + for (size_t offset = 0; offset < OC_loop; offset += vlen) + compute(offset, offset / vlen, false); + advance_ptrs_imm(OC_loop); + sub(reg_tmp, OC_loop); + jnz(oc_loop); + } + } + + if (OC_tail) { + for (size_t offset = 0; offset < OC_tail; offset += vlen) { + bool use_mask = (offset + vlen) > OC_tail; + compute(offset, offset / vlen, use_mask); + } + advance_ptrs_imm(OC_tail); + } + + rewind_ptrs(); + sub(reg_len, OC_); + cmp(reg_len, OC_); + jge(main_loop, T_NEAR); + } + } + L(main_loop_end); + + // Epilogue loop + Label epilogue_end; + { + cmp(reg_len, 0); + je(epilogue_end, T_NEAR); + + Label epilogue_loop, epilogue_loop_tail; + cmp(reg_len, vlen); + jle(epilogue_loop_tail, T_NEAR); + L(epilogue_loop); { + compute(0, 0, false); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + cmp(reg_len, vlen); + jge(epilogue_loop, T_NEAR); + } + + L(epilogue_loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + mov(reg_rem_mask_short, 1); + shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen + sub(reg_rem_mask_short, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + compute(0, 0, true); + } + + L(epilogue_end); + + postamble(); + + ker_ = getCode(); +} + +template +void _gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::operator () + (dst_data_t *dst, const acc_data_t *acc, const char *bias, + const float *scales, float nslope, float sum_scale, float signed_scale, + int g, size_t start, size_t end) +{ + using math::get_bias; + + if (end <= start) + return; + + if (ker_) { + // JIT + ker_args args; + size_t oc_offset = start % OC_; + size_t os_offset = start / OC_; + args.acc = acc + start; + args.dst = dst + os_offset * dst_os_stride_ + oc_offset; + args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset); + args.nslope = nslope; + args.sum_scale = sum_scale; + args.signed_scale = signed_scale; + args.len = end - start; + args.oc_offset = oc_offset; + ker_(&args); + } + else { + // Fallback + const size_t first_oc = start % OC_; + const size_t last_oc = (end - 1) % OC_; + const size_t first_os = start / OC_; + const size_t last_os = (end - 1) / OC_; + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * dst_os_stride_ + oc; + + float d = (float)(acc[acc_off]); + if (jcp_.signed_input) + d *= signed_scale; + + if (do_bias_) + d += get_bias(bias, g * jcp_.oc + oc, + bias_data_type_); + + d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_]; + if (do_sum_) + d += sum_scale * dst[dst_off]; + if (do_relu_ && d < 0) + d *= nslope; + dst[dst_off] = qz_a1b0()(d); + } + } + } +}; + +template +void _gemm_x8s8s32x_convolution_fwd_t:: +execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base, + const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base, + const memory_tracking::grantor_t &scratchpad) const { + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const auto src_md = memory_desc_wrapper(pd()->src_md()); + const size_t src_mb_stride = src_md.blk_off(1); + const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic; + + const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); + const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; + + const auto dst_md = memory_desc_wrapper(pd()->dst_md()); + const size_t dst_mb_stride = dst_md.blk_off(1); + const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc; + + const float *scales = pd()->attr()->output_scales_.scales_; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_sum = post_ops.contain(primitive_kind::sum, 0); + const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0; + + float nslope = 0; + for (int idx = 0; idx < post_ops.len_; ++idx) { + const auto &e = post_ops.entry_[idx]; + if (e.is_relu(true, false)) { + nslope = e.eltwise.alpha; + break; + } + } + + auto col = scratchpad.get(key_conv_gemm_col) + + (ptrdiff_t)ithr * jcp.im2col_sz; + src_data_t *__restrict imtr = scratchpad.get(key_conv_gemm_imtr) + + (ptrdiff_t)ithr * jcp.is * jcp.ic; + auto acc = scratchpad.get(key_conv_int_dat_in_acc_dt) + + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc; + + const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc; + const int32_t *_wei_comp = (const int32_t *)(wei_base + offset); + + int g{ 0 }, n{ 0 }, ohb{ 0 }, owb{ 0 }; + size_t start = 0, end = 0; + + const int nb_oh = div_up(jcp.oh, jcp.oh_block); + const int nb_ow = div_up(jcp.ow, jcp.ow_block); + const size_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, + nb_oh, owb, nb_ow); + + for (size_t iwork = start; iwork < end; ++iwork) { + int oh = ohb * jcp.oh_block; + int ow = owb * jcp.ow_block; + const src_data_t *__restrict src = src_base + n * src_mb_stride + + g * src_g_stride; + const wei_data_t *__restrict wei = wei_base + g * wei_g_stride; + dst_data_t *__restrict dst = + dst_base + n * dst_mb_stride + g * dst_g_stride; + const int32_t *wei_comp = _wei_comp + g * jcp.oc; + const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); + const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); + + if (jcp.im2col_sz) + jit_gemm_convolution_utils::im2col_u8( + jcp, src, imtr, col, oh, h_step, ow, w_step); + + const int M = jcp.oc; + const int K = jcp.ks * jcp.ic; + const int N = h_step * w_step; + const int LDA = M * jcp.ngroups; + const int LDB = jcp.im2col_sz ? N : K; + const char *BT = jcp.im2col_sz ? "T" : "N"; + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + const float onef = 1.0, zerof = 0.0; + gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", + &M, &N, &K, &onef, wei, &LDA, &off_a, + jcp.im2col_sz ? col : (uint8_t *)src, &LDB, &off_b, + &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c); + + auto wei_adj_scale = + (wei_md.extra().flags | memory_extra_flags::scale_adjust) + ? wei_md.extra().scale_adjust : 1.f; + + parallel(0, [&](int ithr, int nthr) { + size_t start, end; + balance211((size_t)N * jcp.oc, nthr, ithr, start, end); + (*pp_ker_)(dst + (oh * jcp.ow + ow) * pp_ker_->dst_os_stride_, + acc, bia_base, scales, nslope, sum_scale, + 1.f / wei_adj_scale, g, start, end); + }); + + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, + owb, nb_ow); + } +} + +template +void _gemm_u8s8s32x_convolution_bwd_data_t:: +execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + auto scratchpad = this->scratchpad(ctx); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base, + bia_base, diff_src_base, scratchpad); + }); +} + +template +void _gemm_u8s8s32x_convolution_bwd_data_t:: +execute_backward_data_thr(const int ithr, const int nthr, + const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, + const char *bia_base, diff_src_data_t *diff_src_base, + const memory_tracking::grantor_t &scratchpad) const +{ + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md()); + const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1); + const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc; + + const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); + const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; + + const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md()); + const size_t diff_src_mb_stride = diff_src_md.blk_off(1); + const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic; + const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1); + + /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */ + const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1); + const float *scales = pd()->attr()->output_scales_.scales_; + const size_t work_amount = jcp.ngroups * jcp.mb; + + auto col = scratchpad.get(key_conv_gemm_col) + + (ptrdiff_t)ithr * jcp.im2col_sz; + auto acc = scratchpad.get(key_conv_int_dat_in_acc_dt) + + (ptrdiff_t)ithr * jcp.is * jcp.ic; + + int n{0}, g{0}; + size_t start = 0, end = 0; + + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups); + + for (size_t iwork = start; iwork < end; ++iwork) { + const diff_dst_data_t *diff_dst = diff_dst_base + + n * diff_dst_mb_stride + g * diff_dst_g_stride; + const wei_data_t *wei = wei_base + g * wei_g_stride; + diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride + + g * diff_src_g_stride; + + const int M = jcp.ks * jcp.ic; + const int N = jcp.os; + const int K = jcp.oc; + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + const float onef = 1.0, zerof = 0.0; + const int LD = K * jcp.ngroups; + + gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, + wei, &LD, &off_a, diff_dst, &LD, &off_b, + &zerof, jcp.im2col_sz ? col : acc, &M, &off_c); + + if (jcp.im2col_sz) + jit_gemm_convolution_utils::col2im_s32(jcp, col, acc); + + parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) { + float d = (float)acc[is * jcp.ic + ic]; + if (jcp.with_bias) + d += get_bias(bia_base, g * jcp.ic + ic, + pd()->desc()->bias_desc.data_type); + d *= scales[(g * jcp.ic + ic) * scale_idx_mult]; + const size_t diff_src_off = is * diff_src_os_stride + ic; + diff_src[diff_src_off] = + qz_a1b0()(d); + }); + nd_iterator_step(n, jcp.mb, g, jcp.ngroups); + } +} + +using namespace data_type; + +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; + +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; +template struct _gemm_x8s8s32x_convolution_fwd_t; + +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +template struct _gemm_u8s8s32x_convolution_bwd_data_t; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp new file mode 100644 index 0000000000..9e77b890d5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_X8S8S32X_CONVOLUTION_HPP +#define GEMM_X8S8S32X_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" +#include "gemm_convolution_utils.hpp" + +#include "gemm/gemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, + _gemm_x8s8s32x_convolution_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, s8, data_type::undef, dst_type, + s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, f32, s32, s8, u8)) + && !has_zero_dim_memory() + && set_default_formats_common( + dat_tag(), format_tag::any, dat_tag()) + && post_ops_ok() + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*dst_md(), dat_tag()) + && set_or_check_wei_format(); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), weights_md(0), dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + bool set_or_check_wei_format() { + using namespace format_tag; + + const bool is_src_s8 = src_md_.data_type == data_type::s8; + + memory_desc_t want_wei_md = weights_md_; + memory_desc_init_by_tag(want_wei_md, with_groups() ? hwigo : hwio); + + if (is_src_s8) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups() ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md_.format_kind == format_kind::any) { + weights_md_ = want_wei_md; + return true; + } + + return weights_md_ == want_wei_md; + } + + bool post_ops_ok() const { + using namespace mkldnn::impl::primitive_kind; + auto const &po = attr()->post_ops_; + auto is_relu = [&](int idx) { + return po.entry_[idx].is_relu(true, false); }; + + switch (po.len_) { + case 0: return true; + case 1: return is_relu(0) || po.contain(sum, 0); + case 2: return po.contain(sum, 0) && is_relu(1); + default: return false; + } + return false; + } + }; + + _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true), pp_ker_(nullptr) + { pp_ker_ = new pp_ker_t(pd()); } + ~_gemm_x8s8s32x_convolution_fwd_t() { delete pp_ker_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + // XXX: this is throwaway code that will become unnecessary when we have a + // sufficiently advanced igemm jit generator that supports quantization, + // relu, and whatnot + class pp_ker_t : jit_generator { + public: + DECLARE_CPU_JIT_AUX_FUNCTIONS( + _gemm_x8s8s32x_convolution_fwd_t::pp_kernel); + pp_ker_t(const pd_t *pd); + + void operator()(dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, + float nslope, float sum_scale, float signed_scale, + int g, size_t start, size_t end); + + size_t dst_os_stride_; + + private: + void generate(); + + struct ker_args { + dst_data_t *dst; + const acc_data_t *acc; + const char *bias; + const float *scales; + float nslope; + float sum_scale; + float signed_scale; + size_t len; + size_t oc_offset; + }; + void(*ker_)(const ker_args *args); + + const jit_gemm_conv_conf_t &jcp_; + size_t OC_; + size_t OS_; + data_type_t bias_data_type_; + size_t bias_data_type_size_; + size_t scale_idx_mult_; + bool do_bias_; + bool do_relu_; + bool do_sum_; + bool do_signed_scaling_; + size_t vlen_; + }; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src_base, const wei_data_t *wei_base, + const char *bia_base, dst_data_t *dst_base, + const memory_tracking::grantor_t &scratchpad) const; + + int nthr_; + pp_ker_t *pp_ker_; + +}; + +template +struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t{ + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, + _gemm_u8s8s32x_convolution_bwd_data_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(dst_type, s8, data_type::undef, u8, s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, f32, s32, s8, u8)) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && attr()->post_ops_.has_default_values() + && memory_desc_matches_tag(*diff_src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), diff_src_md(), weights_md(), diff_dst_md(), + mkldnn_get_max_threads()); + } + + virtual bool support_bias() const override { return true; } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + format_tag_t wei_tag() const { + return with_groups() ? format_tag::hwigo : format_tag::hwio; + } + }; + + _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + void execute_backward_data_thr(const int ithr, const int nthr, + const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, + const char *bia_base, diff_src_data_t *diff_src_base, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp new file mode 100644 index 0000000000..1e435a233a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp @@ -0,0 +1,453 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "simple_q10n.hpp" + +#include "gemm/gemm.hpp" +#include "gemm_x8s8s32x_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace math; +using namespace format_tag; +using namespace memory_tracking::names; + +template +gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::pp_kernel_t( + const pd_t *pd, bool dst_is_acc) + : ker_(nullptr), OC_(pd->OC()) + , bias_data_type_(data_type::undef), bias_data_type_size_(0) + , scale_idx_mult_(0), do_bias_(false), do_relu_(false) +{ + using namespace types; + + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + + auto &post_ops = pd->attr()->post_ops_; + do_relu_ = post_ops.len_ == 1; + do_bias_ = pd->with_bias(); + bias_data_type_ = pd->desc()->bias_desc.data_type; + if (do_bias_) { + assert(bias_data_type_ != data_type::undef); + bias_data_type_size_ = data_type_size(bias_data_type_); + } + + if (!mayiuse(avx512_core)) + // use fallback code for older CPUs since they do not have optimized + // x8s8s32 GEMM anyways. The configuration variables above are used by + // the fallback code. + return; + else + generate(); +} + +template +void gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::generate() +{ + using namespace Xbyak; + using namespace utils; + + // TODO: clean-up + Reg64 reg_param = abi_param1; + Reg64 reg_dst = rdx; + Reg64 reg_acc = rax; + Reg64 reg_bias = rbx; + Reg64 reg_scales = rsi; + + Reg64 reg_len = r8; + Reg64 reg_tmp = rcx; // intentional for shifting purposes + Reg64 reg_oc_offset = r9; + Reg64 reg_rem_mask = r10; + Opmask kreg_rem_mask = k1; + Opmask kreg_relu_cmp = k2; + + const size_t vlen = cpu_isa_traits::vlen / sizeof(float); + + Zmm vreg_zero = Zmm(0); + Zmm vreg_scale = Zmm(1); + Zmm vreg_nslope = Zmm(2); + + auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); }; + auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); }; + + preamble(); + +#define PARAM_OFF(x) offsetof(ker_args, x) + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); + if (scale_idx_mult_ == 0) + vbroadcastss(vreg_scale, dword[reg_scales]); +#undef PARAM_OFF + + if (do_relu_ || dst_type == data_type::u8) + vxorps(vreg_zero, vreg_zero, vreg_zero); + + // Load accumulated value, convert to float, apply bias (if any), scaling, + // and relu (if any); then convert to destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask; + vmovups(vreg_scale, scale_addr); + } + + auto vreg_dst_ = vreg_dst(idx); + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask; + vcvtdq2ps(vreg_dst_, acc_addr); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask; + + switch (bias_data_type_) { + case data_type::s8: + vpmovsxbd(vreg_bias_, bias_addr); + break; + case data_type::u8: + vpmovzxbd(vreg_bias_, bias_addr); + break; + case data_type::s32: + case data_type::f32: + vmovups(vreg_bias_, bias_addr); + break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); + } + + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); + if (do_relu_) { + vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); + vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); + } + + if (dst_type == data_type::u8) + vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + if (dst_type != data_type::f32) { + vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } + + auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; + switch (dst_type) { + case data_type::s8: + vpmovsdb(dst_addr, vreg_dst_); + break; + case data_type::u8: + vpmovusdb(dst_addr, vreg_dst_); + break; + case data_type::f32: + case data_type::s32: + vmovups(dst_addr, vreg_dst_); + break; + default: assert(!"unimplemented"); + } + }; + + // Advance all pointers by an immediate + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * sizeof(dst_data_t)); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); + } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); + }; + + // Advance all pointers by a value stored in a register + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); + } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); + }; + + // Rewind pointers that point to data that is indixed by output channel + // (bias or per-oc scaling factors) + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); + } + }; + + // <-------------------- OC -------------------------------> + // + // ^ +....................+----------------------------------+ + // | : not accessed | Prologue loop | + // | +--------------------+----------------------------------+ + // | | + // M | Main loop (unrolled) | + // B | | + // +--------------------------------+----------------------+ + // | | Epilogue loop | not accessed : + // v +--------------------------------+......................+ + + Label prologue_end; + cmp(reg_oc_offset, 0); + je(prologue_end, T_NEAR); + + // Prologue loop + { + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); + + Label prologue_loop, prologue_loop_tail, prologue_loop_end; + cmp(reg_tmp, vlen); + jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?) + L(prologue_loop); { + compute(0, 0, false); + advance_ptrs_imm(vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); + jge(prologue_loop, T_NEAR); + } + + L(prologue_loop_tail); + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here + sub(reg_rem_mask, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask, reg_rem_mask); + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); + + L(prologue_loop_end); + rewind_ptrs(); + } + L(prologue_end); + + // Main loop + Label main_loop_end; + { + cmp(reg_len, OC_); + jle(main_loop_end, T_NEAR); + + Label main_loop; + L(main_loop); { + size_t def_unroll = 4; + size_t max_unroll = 13; + + size_t OC_loop, OC_tail; + if (OC_ < max_unroll * vlen) { + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } else { + OC_loop = vlen * def_unroll; + OC_tail = OC_ % OC_loop; + } + + assert(!!OC_loop || !!OC_tail); + + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + unsigned tail_mask = (1 << vlen_tail) - 1; + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask, reg_tmp); + } + + if (OC_loop) { + mov(reg_tmp, rnd_dn(OC_, OC_loop)); + Label oc_loop; + L(oc_loop); { + for (size_t offset = 0; offset < OC_loop; offset += vlen) + compute(offset, offset / vlen, false); + advance_ptrs_imm(OC_loop); + sub(reg_tmp, OC_loop); + jnz(oc_loop); + } + } + + if (OC_tail) { + for (size_t offset = 0; offset < OC_tail; offset += vlen) { + bool use_mask = (offset + vlen) > OC_tail; + compute(offset, offset / vlen, use_mask); + } + advance_ptrs_imm(OC_tail); + } + + rewind_ptrs(); + sub(reg_len, OC_); + cmp(reg_len, OC_); + jge(main_loop, T_NEAR); + } + } + L(main_loop_end); + + // Epilogue loop + Label epilogue_end; + { + cmp(reg_len, 0); + je(epilogue_end, T_NEAR); + + Label epilogue_loop, epilogue_loop_tail; + cmp(reg_len, vlen); + jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?) + L(epilogue_loop); { + compute(0, 0, false); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + cmp(reg_len, vlen); + jge(epilogue_loop, T_NEAR); + } + + L(epilogue_loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16 + sub(reg_rem_mask, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask, reg_rem_mask); + compute(0, 0, true); + } + + L(epilogue_end); + + postamble(); + + ker_ = getCode(); +} + +template +void gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::operator ()( + dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, float nslope, + size_t start, size_t end) +{ + using math::get_bias; + + if (end <= start) + return; + + if (ker_) { + // JIT + ker_args args; + size_t oc_offset = start % OC_; + args.dst = dst + start; + args.acc = acc + start; + args.bias = bias + oc_offset * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * oc_offset; + args.nslope = nslope; + args.len = end - start; + args.oc_offset = oc_offset; + ker_(&args); + } else { + // Fallback + size_t oc = start % OC_; + for (size_t i = start; i < end; i++) { + float d = (float)acc[i]; + float b = get_bias(bias, oc, bias_data_type_); + d = d + b; + d *= scales[oc * scale_idx_mult_]; + if (do_relu_ && d < 0) + d *= nslope; + dst[i] = qz_a1b0()(d); + oc = (oc == OC_ - 1) ? 0 : oc + 1; + } + } +}; + +template +void gemm_x8s8s32x_inner_product_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->weights_md(), oiw, oihw, oidhw, oi); + + const int M = OC; + const int N = MB; + const int K = pd()->IC_total_padded(); + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + + const float *scales = pd()->attr()->output_scales_.scales_; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f; + + acc_data_t *acc = pd()->dst_is_acc_ + ? (acc_data_t *)dst + : scratchpad(ctx).template get(key_iprod_int_dat_in_acc_dt); + + const float onef = 1.0, zerof = 0.0; + gemm_s8x8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef, weights, + wei_tr ? &K : &M, &off_a, src, &K, &off_b, &zerof, acc, &M, &off_c); + + if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_ + || pd()->with_bias()) { + const bool force_sequential = MB * OC < 2000; + parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { + size_t start, end; + balance211((size_t)OC * MB, nthr, ithr, start, end); + (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end); + }); + } +} + +using namespace data_type; + +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; +template struct gemm_x8s8s32x_inner_product_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp new file mode 100644 index 0000000000..ac6a5c8f85 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp @@ -0,0 +1,166 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_X8S8S32X_INNER_PRODUCT_HPP +#define GEMM_X8S8S32X_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "gemm/gemm.hpp" +#include "jit_generator.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct gemm_x8s8s32x_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T(src_type == data_type::u8 + ? IGEMM_S8U8S32_IMPL_STR + : IGEMM_S8S8S32_IMPL_STR, + gemm_x8s8s32x_inner_product_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == src_type + && dst_md()->data_type == dst_type + && weights_md()->data_type == s8 + && IMPLICATION(with_bias(), utils::one_of( + weights_md(1)->data_type, f32, s32, s8, u8)) + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_, + attr()->post_ops_.entry_[0].is_relu(true, false)) + && dense_gemm_consitency_check(src_md(), weights_md(), + dst_md()); + if (!ok) return status::unimplemented; + + dst_is_acc_ = utils::one_of(dst_type, s32, f32); + + init_scratchpad(); + + return status::success; + } + + bool dst_is_acc_; + + protected: + status_t set_default_params() { + using namespace format_tag; + if (src_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md_, + utils::pick(ndims() - 2, nc, nwc, nhwc, ndhwc))); + } + if (dst_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md_, nc)); + if (weights_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md_, + utils::pick(ndims() - 2, io, wio, hwio, dhwio))); + } + return inner_product_fwd_pd_t::set_default_params(); + } + + private: + void init_scratchpad() { + if (!dst_is_acc_) { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book( + memory_tracking::names::key_iprod_int_dat_in_acc_dt, + sizeof(acc_data_t) * MB() * OC()); + } + } + }; + + gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + { pp_kernel_ = new pp_kernel_t(apd, pd()->dst_is_acc_); } + ~gemm_x8s8s32x_inner_product_fwd_t() { delete pp_kernel_; } + + typedef typename prec_traits::type data_t; + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + // XXX: this is throwaway code that will become unnecessary when we have a + // sufficiently advanced igemm jit generator that supports quantization, + // relu, and whatnot + class pp_kernel_t: jit_generator { + public: + DECLARE_CPU_JIT_AUX_FUNCTIONS( + gemm_x8s8s32x_inner_product_fwd_t::pp_kernel); + pp_kernel_t(const pd_t *pd, bool dst_is_acc); + + void operator()(dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, float nslope, + size_t start, size_t end); + private: + void generate(); + + struct ker_args { + dst_data_t *dst; + const acc_data_t *acc; + const char *bias; + const float *scales; + float nslope; + size_t len; + size_t oc_offset; + }; + void (*ker_)(const ker_args *args); + + size_t OC_; + data_type_t bias_data_type_; + size_t bias_data_type_size_; + size_t scale_idx_mult_; + bool do_bias_; + bool do_relu_; + }; + + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + pp_kernel_t *pp_kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp new file mode 100644 index 0000000000..6fa251d465 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp @@ -0,0 +1,674 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_avx2_1x1_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop, bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop( + int load_loop_blk, int ur) +{ + auto vreg_load = [=](int i) { + return Ymm(ur * load_loop_blk + i); + }; + + auto vreg_accum = [=](int i, int j) { + return Ymm(j * load_loop_blk + i); + }; + + auto bias_ptr = [=](int i) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) + { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? (height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt]; + }; + + auto output_ptr = [=](int i, int j) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j]; + default: + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float)]; + } + }; + + auto init = [=]() { + Label init_done, init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) + vmovups(vreg_accum(i, j), bias_ptr(i)); + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r = vreg_accum(i, j); + vxorps(r, r, r); + } + + L(init_done); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(vreg_load(i), load_ptr(0, i)); + vbroadcastss(vreg_bcast, bcast_ptr(0, 0)); + }; + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r = vreg_accum(i, j); + vaddps(r, r, output_ptr(i, j)); + } + + L(store_noadd); + + if (jcp.with_eltwise) { + assert(ur * load_loop_blk < 14); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + L(store_norelu); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + vmovups(output_ptr(i, j), vreg_accum(i, j)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + if (mayiuse(avx2)) + vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(vtmp, vreg_bcast, vreg_load(i)); + vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp); + } + if (j == ur - 1 && !(last_block + && u == jcp.reduce_loop_unroll - 1)) + vmovups(vreg_load(i), load_ptr(u + 1, i)); + } + if (j < ur - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1)); + } + if (!last_block || u < jcp.reduce_loop_unroll - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0)); + } + }; + + Label reduce_loop, reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) +{ + if (!jcp.with_bias || jcp.prop_kind != backward_weights) + return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i) { return Ymm(i); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r = diff_bias_reg(i); + vxorps(r, r, r); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(diff_bias_reg(i), diff_bias_ptr(i)); + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) + vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i)); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) + vmovups(diff_bias_ptr(i), diff_bias_reg(i)); + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_avx2_1x1_conv_kernel_f32::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + if (jcp.with_bias) { + if (jcp.prop_kind == backward_weights) { + sub(rsp, stack_space_needed); + mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + } else + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + } + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto generate_load_loop_body = [=] (int load_loop_blk) { + generate_bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + add(reg_output_data, + load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); + break; + case backward_weights: + for (int i = 0; i < load_loop_blk; i++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + Label load_loop_blk_8; + Label load_loop_blk_16; + Label load_loop_blk_24; + Label load_loop_blk_end; + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16, T_NEAR); + + cmp(reg_load_loop_work, 16); + jle(load_loop_blk_16, T_NEAR); + + L(load_loop_blk_24); { + generate_diff_bias_loop(3); + generate_load_loop_body(3); + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16); + cmp(reg_load_loop_work, 24); + jge(load_loop_blk_24); + } + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, 8); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } + + const int is_bwd_d = jcp.prop_kind == backward_data; + + format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + + const int simd_w = 8; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.ih == jcp.oh && jcp.iw == jcp.ow + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + // TODO: remove this restriction + // optimized 1x1 bwd_w does not support Intel AVX + if (jcp.prop_kind == backward_weights && !mayiuse(avx2)) + return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.ur = mayiuse(avx2) ? 4 : 3; // Intel AVX support + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx2_1x1_conv_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp new file mode 100644 index 0000000000..bfdeb2b18d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp @@ -0,0 +1,110 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP +#define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_1x1_conv_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32) + + jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); + } + + ~jit_avx2_1x1_conv_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using ymm_t = const Xbyak::Ymm; + + reg64_t reg_bcast_data = rax; + reg64_t reg_load_data = rsi; + reg64_t reg_output_data = rbx; + reg64_t aux_reg_bcast_data = rdx; + reg64_t aux1_reg_bcast_data = abi_not_param1; + reg64_t aux_reg_load_data = abi_param1; + reg64_t aux_reg_output_data = rbp; + reg64_t reg_load_loop_work = r9; + reg64_t reg_bcast_loop_work = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t load_loop_iter = r13; + reg64_t bcast_loop_iter = r14; + reg64_t reduce_loop_iter = r15; + reg64_t imm_addr64 = reduce_loop_iter; + reg64_t reg_reduce_pos_flag = r8; + reg64_t reg_output_stride = r12; + reg64_t reg_bias_data = r12; + reg64_t reg_diff_bias_data = bcast_loop_iter; + + int reg_diff_bias_data_stack_offt = 0; + int stack_space_needed = 8; + + ymm_t vreg_bcast = ymm_t(15); + ymm_t vtmp = ymm_t(14); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate_bcast_loop(int load_loop_blk); + void generate_reduce_loop(int load_loop_blk, int ur); + void generate_diff_bias_loop(int load_loop_blk); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp new file mode 100644 index 0000000000..f116ac9056 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp @@ -0,0 +1,545 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx2_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + +/* convolution forward */ + +void jit_avx2_1x1_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).get(key_conv_rtus_space); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int ndims = dst_d.ndims(); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); + rp.os = p.bcast_dim; + + int ocb = 0; + while (ocb < jcp.nb_load) { + const int load_step = step(jcp.nb_load_blocking, + jcp.nb_load - ocb, jcp.nb_load_blocking_max); + + const int _ocb = g * nb_oc + ocb; + p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + load_step * jcp.oc_block); + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + + p.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking >= nb_ic + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + nb_ic_blocking * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + + if (ocb == 0) { + rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); + rtus_driver_->ker_(&rp); + } + + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); + + kernel_->jit_ker(&p); + } + + ocb += load_step; + } + + iwork += bcast_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad(ctx).get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +/* convolution backward wtr data */ + +void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).get(key_conv_rtus_space); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + const int ndims = diff_dst_d.ndims(); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int load_step = 0; + for (int icb = 0; icb < jcp.nb_load; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + load_step * jcp.ic_block); + rp.icb = p.load_dim / jcp.ic_block; + + int bcast_step; + for (int iwork = start; iwork < end; iwork += bcast_step) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + const int _icb = g * nb_ic + icb; + rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb = 0; ocb < jcp.nb_reduce; + ocb += jcp.nb_reduce_blocking) { + const int _ocb = g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, + ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + nb_oc_blocking * jcp.oc_block); + + kernel_->jit_ker(&p); + } + + if (pd()->rtus_.reduce_src_) + rtus_driver_->ker_(&rp); + } + } + }; + + parallel(0, ker); +} + +/* convolution backward wtr weights */ + +jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t( + const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , rtus_driver_(nullptr) +{ + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + reducer_weights_ = + new cpu_reducer_2d_t(pd()->reducer_wei_conf_); + reducer_bias_ = new cpu_reducer_t(pd()->reducer_bia_conf_); + init_rtus_driver(this); +} + +void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto scratchpad = this->scratchpad(ctx); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; + + auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_wei); + auto rw = this->reducer_weights_; + rw->init(reducer_wei_scratchpad); + + const int ndims = diff_dst_d.ndims(); + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + const int bcast_work = div_up(nb_ic, nb_ic_blocking); + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + const int load_work = div_up(nb_oc, nb_oc_blocking); + + const int sp_dim = jcp.reduce_dim; + const int mb_sp_work = jcp.mb * sp_dim; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image, + data_t *store_to, size_t store_to_ld, const data_t *diff_dst, + const data_t *src, int ithr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + p.output_stride = store_to_ld * sizeof(float); + const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block; + + int oc_b_step = 0; + for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) { + oc_b_step = step(12, nb_oc_blocking - oc_b, 18); + p.load_dim = oc_b_step * jcp.oc_block; + + int ic_b_step = 0; + for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) { + ic_b_step = step(12, nb_ic_blocking - ic_b, 18); + p.bcast_dim = ic_b_step * jcp.ic_block; + rp.icb = p.bcast_dim / jcp.ic_block; + + p.output_data = store_to + oc_b * store_to_ld + + ic_b * jcp.ic_block * jcp.oc_block; + + /* spatial reduction */ + int sp_step = 0; + for (int sp = sp_start; sp < sp_end; sp += sp_step) { + sp_step = step(sp_step_def, sp_end - sp, 192); + p.reduce_dim = sp_step; + rp.os = p.reduce_dim; + + p.first_last_flag = sp == sp_start && first_image + ? FLAG_REDUCE_FIRST : 0; + + p.load_data = diff_dst + + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + (ic_b * jcp.is + sp) * jcp.ic_block; + if (ndims == 3) + rp.src = src + + iw * src_d.blocking_desc().strides[2]; + else + rp.src = src + + ih * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + + if (oc_b == 0) + rtus_driver_->ker_(&rp); + + p.bcast_data = rp.ws; + } else + p.bcast_data = src + + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block; + + kernel_->jit_ker(&p); + } + } + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == rw->balancer().nthr_); + + const int w_njobs = rw->balancer().ithr_njobs(ithr); + if (w_njobs == 0) return; + + /* setup: independent work (oc, ic) */ + const int w_job_start = rw->balancer().ithr_job_off(ithr); + int g{0}, load_i{0}, bcast_i{0}; + nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work, + bcast_i, bcast_work); + + /* setup: reduction work (mb, sp) */ + int mb_sp_start{0}, mb_sp_end{0}; + balance211(mb_sp_work, rw->balancer().nthr_per_group_, + rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end); + int img_start{0}, sp_start{0}; + nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim); + + /* independent work */ + for (int iwork = 0; iwork < w_njobs; ++iwork) { + const int oc_b = nb_oc_blocking * load_i; + const int ic_b = nb_ic_blocking * bcast_i; + + const int _ic_b = g * nb_ic + ic_b; + const int _oc_b = g * nb_oc + oc_b; + + data_t *store_to; + size_t store_to_ld; + + if (rw->balancer().nthr_per_group_ == 1) { + const size_t off = pd()->with_groups() + ? diff_weights_d.blk_off(g, oc_b, ic_b) + : diff_weights_d.blk_off(oc_b, ic_b); + store_to = &diff_weights[off]; + store_to_ld = jcp.ic * jcp.oc_block; + } else { + const size_t off = iwork * rw->balancer().job_size_; + store_to = + rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off; + store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block; + } + + /* reduction work */ + int img = img_start; + int sp = sp_start; + int sp_step = 0; + for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step) + { + sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp); + + const bool first_image = img == img_start; + oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to, + store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)], + &src[src_d.blk_off(img, _ic_b)], ithr); + + sp = 0; + img += 1; + } + + nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i, + bcast_work); + } + rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * nb_oc + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = + rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 8; ++o) d_bias[o] = 0.; + + for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 8; ++o) + d_bias[o] += d_dst[o]; + d_dst += 8; + } + + nd_iterator_step(g, jcp.ngroups, ocb, nb_oc); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(0, [&](const int ithr, const int nthr) { + ker(ithr, nthr); + if (pd()->with_bias()) + ker_bias(ithr, nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp new file mode 100644 index 0000000000..9762242173 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp @@ -0,0 +1,344 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_1x1_CONVOLUTION_HPP +#define CPU_JIT_AVX2_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx2_1x1_conv_kernel_f32.hpp" +#include "jit_uni_1x1_conv_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t { + // TODO: (Roma) Code duplication duplication! Remove with templates + // (maybe...)! + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *src_d, *weights_md(), *dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx2_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + rtus_driver_t *rtus_driver_; +}; + +struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i) + : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , rtus_driver_(nullptr) + { + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx2_1x1_convolution_bwd_data_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + rtus_driver_t *rtus_driver_; +}; + +struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr()); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_wei); + reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + cpu_reducer_2d_t::conf_t reducer_wei_conf_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const int ic_block = jcp_.bcast_block; + const int nb_ic = jcp_.nb_bcast; + const int nb_ic_blocking = jcp_.nb_bcast_blocking; + const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking); + + const int oc_block = jcp_.load_block; + const int nb_oc = jcp_.nb_load; + const int nb_oc_blocking = jcp_.nb_load_blocking; + const int load_work = utils::div_up(nb_oc, nb_oc_blocking); + + const int job_size + = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block; + const int njobs_x = bcast_work; + const int njobs_y = jcp_.ngroups * load_work; + + const int max_threads = mkldnn_get_max_threads(); + const size_t max_buffer_size = max_threads * job_size * 8; + + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(max_threads, + oc_block, jcp_.ngroups * jcp_.oc / oc_block, + jcp_.mb, max_buffer_size)); + } + + reducer_wei_conf_.init( + reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x, + jcp_.mb * jcp_.nb_reduce, max_buffer_size), + job_size / nb_oc_blocking, nb_oc_blocking, ic_block, + nb_ic * ic_block * oc_block, nb_oc); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd); + + ~jit_avx2_1x1_convolution_bwd_weights_t() { + delete kernel_; + delete rtus_driver_; + delete reducer_weights_; + delete reducer_bias_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + cpu_reducer_2d_t *reducer_weights_; + cpu_reducer_t *reducer_bias_; + rtus_driver_t *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp new file mode 100644 index 0000000000..e24770a2da --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp @@ -0,0 +1,1501 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_avx2_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); + int jj_end = ur_w + - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w)); + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t inp_off; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) + inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw + + (ki*dilate_w + jj*stride_w - pad_l)); + else + inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w + - pad_l)*ic_blk + ifm2); + vbroadcastss(Ymm(oc_blocks * ur_w + jj), + make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + + ki * ic_blk * oc_blk + ifm2 * oc_blk; + vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]); + for (int jj = jj_start; jj < jj_end; jj++) + if (mayiuse(avx2)) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(oc_blocks * ur_w + jj), ymm15); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); + } + } + } + } +} + +void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, + int pad_l, int pad_r, char pad_tag, + int oc_blocks, char oc_blocks_tag) +{ + Label kw_loop; + + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + xor_(ki_iter, ki_iter); + L(kw_loop); + { + int jj_start = 0; + int jj_end = ur_w; + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t inp_off; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) + inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw + + (jj * stride_w - pad_l)); + else + inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk + + ifm2); + vbroadcastss(Ymm(oc_blocks * ur_w + jj), + make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); + } + for (int ii = 0; ii < oc_blocks; ii++) { + int aux_kernel_offset = + ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk; + vmovups(ymm15, ptr[aux_reg_kernel + + sizeof(float) * aux_kernel_offset]); + for (int jj = jj_start; jj < jj_end; jj++) + if (mayiuse(avx2)) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(oc_blocks * ur_w + jj), ymm15); + else { // Intel AVX support + vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); + } + } + } + add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? dilate_w : ic_blk * dilate_w)); + + inc(ki_iter); + cmp(ki_iter, kw); + jl(kw_loop, T_NEAR); + } +} + +void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w, + int pad_l, int pad_r, char pad_tag, + int oc_blocks, char oc_blocks_tag) +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ow = jcp.ow; + int oh = jcp.oh; + int od = jcp.od; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : ic_blk; + const int inp_off = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? dilate_w : ic_blk * dilate_w; + + Label init_done, init_first; + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first, T_NEAR); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; + vmovups(Ymm(ur_w * ii + jj), + make_safe_addr(reg_output, offt, reg_long_offt)); + } + } + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done, T_NEAR); + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + yword[reg_bias + sizeof(float) * ii * oc_blk]); + } + + jmp(init_done); + + L(init_first); + if (this->jcp.with_bias) { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + vmovups(Ymm(ur_w * ii + jj), + yword[reg_bias + sizeof(float) * ii * oc_blk]); + } else { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj)); + } + + L(init_done); + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + } + + Label skip_kh_loop, skip_kd_loop, kd_loop; + if (jcp.ndims == 5) { + push(reg_output); + push(oi_iter); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_input); + + if ((jcp.dilate_d >= jcp.id) + || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { + cmp(reg_ki, 0); + je(skip_kd_loop, T_NEAR); + } + L(kd_loop); + mov(kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_input, aux_reg_inp_d); + mov(aux_reg_kernel, aux_reg_ker_d); + } + + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + cmp(kj, 0); + je(skip_kh_loop, T_NEAR); + } + Label kh_loop; + L(kh_loop); + { + if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { + oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks, + oc_blocks_tag); + sub(aux_reg_input, sizeof(float) * kw * inp_off); + add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); + } else { + oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); + add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); + } + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult); + add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_loop, T_NEAR); + L(skip_kd_loop); + + pop(oi_iter); + pop(reg_output); + } + + Label regular_store; + + if (jcp.with_eltwise) { + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store, T_NEAR); + + eltwise_injector_->compute_vector_range(0, oc_blocks * ur_w); + + L(regular_store); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + const size_t o_off + = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; + Ymm reg_out = Ymm(ur_w * ii + jj); + vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out); + } + } +} + +inline void jit_avx2_conv_fwd_kernel_f32::solve_common( + int oc_blocks, char oc_blocks_tag) +{ + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int n_oi = jcp.ow / ur_w; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + int dilate_w = jcp.dilate_w + 1; + int str_w = jcp.stride_w; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_blk; + + int l_pad = jcp.l_pad; + int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) + width_blk_step(ur_w, l_pad, r_pad1, + 'l', oc_blocks, oc_blocks_tag); // "lrpad" + else + width_blk_step(ur_w, l_pad, 0, + 'l', oc_blocks, oc_blocks_tag); // "lpad" + add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + Label ow_loop; + xor_(oi_iter, oi_iter); + + if (n_oi > 0) { + L(ow_loop); + + width_blk_step(ur_w, 0, 0, + 'm', oc_blocks, oc_blocks_tag); // "middle" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + + if (r_pad1 > 0 && n_oi >=0) { + width_blk_step(ur_w, 0, r_pad1, + 'r', oc_blocks, oc_blocks_tag); // "rpad" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + if (ur_w_tail != 0) + width_blk_step(ur_w_tail, 0, r_pad, + 't', oc_blocks, oc_blocks_tag); // "tail" +} + +void jit_avx2_conv_fwd_kernel_f32::generate() +{ + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + + int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; + Label tail, exit; + + if (jcp.nb_oc > jcp.nb_oc_blocking) { + cmp(reg_oc_blocks, jcp.nb_oc_blocking); + jne(nb_oc_tail ? tail : exit, T_NEAR); + + solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); + jmp(exit, T_NEAR); + + if (nb_oc_tail) { + L(tail); + cmp(reg_oc_blocks, nb_oc_tail); + jne(exit, T_NEAR); + solve_common(nb_oc_tail, '0' + nb_oc_tail); + } + + L(exit); + } else if (jcp.nb_oc == jcp.nb_oc_blocking) { + solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); + } else { + solve_common(nb_oc_tail, '0' + nb_oc_tail); + } + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 :dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 :cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCdhw8c); + } + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } + + const int simd_w = 8; + const bool flat = jcp.ic < simd_w; + const bool mimo = !flat; + + + /* Grouped channel offset to support 'non-blocked data' format for + * convolution sizes with '(input_channel / ngroups) < simd' */ + jcp.nonblk_group_off = + one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic : 1; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + bool args_ok = true + && IMPLICATION(flat, true + && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, + gOdhwi8o)) + && IMPLICATION(mimo, true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, + OIdhw8i8o, gOIdhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c); + if (!args_ok) return status::unimplemented; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.ur_w = 3; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ + + // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively + // Thus, we can only assign 14 or 15 YMMs for data storage + const int num_avail_regs = mayiuse(avx2) ? 15 : 14; + if (!mayiuse(avx2)) { + if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) { + // current register assignment requires more YMMs than available + // adjust one of nb_oc_block, ur_w preserving to ur_w >= l_pad + if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1) + jcp.ur_w -= 1; + else + for (int b = 3; b > 1; b--) + if (jcp.nb_oc % b == 0) { + jcp.nb_oc_blocking = b; + break; + } + } + } + + if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + args_ok = true + && jcp.oc % simd_w == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) + || (jcp.stride_w == 1 && jcp.stride_h == 1)) + && IMPLICATION(mimo, jcp.ic % simd_w == 0); + if (!args_ok) return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + + if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { + /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ + jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, + nstl::min(jcp.ow, num_avail_regs / 2)); + jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + /* check again ... */ + r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) + return status::unimplemented; + } + assert(jcp.nb_oc_blocking > 0); + assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.nb_ic_blocking = 12; + jcp.nb_ic_blocking_max = 16; + } else { + jcp.nb_ic_blocking = 1; + jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; + } + + return status::success; +} + +void jit_avx2_conv_fwd_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow, + int r_overflow) +{ + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int ow = jcp.ow; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_ic_block = jcp.nb_ic_blocking; + int stride_w = jcp.stride_w; + int stride_h = jcp.stride_h; + + Label kd_loop, skip_kd_loop; + Label oc_loop, skip_oc_loop; + + for (int ii = 0; ii < nb_ic_block; ii++) + for (int jj = 0; jj < ur_w; jj++) { + uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + Ymm(ur_w * ii + jj)); + } + + if (one_of(jcp.ndims, 3, 4)) { + cmp(reg_channel_work, 0); + jle(skip_oc_loop, T_NEAR); + xor_(reg_channel, reg_channel); + + mov(aux_reg_ddst_oc_loop, reg_ddst); + mov(aux_reg_kernel_oc_loop, reg_kernel); + + L(oc_loop); + mov(aux_reg_ddst, aux_reg_ddst_oc_loop); + mov(aux_reg_kernel, aux_reg_kernel_oc_loop); + } + + if (jcp.ndims == 5) { + assert(jcp.nb_oc_blocking == 1); + push(oi_iter); + + mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_ddst); + mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]); + + L(kd_loop); + mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]); + } else { + mov(kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_ddst, aux_reg_dst_d); + mov(aux_reg_kernel, aux_reg_ker_d); + } + + Label kh_loop, skip_kh_loop; + cmp(kj, 0); + jle(skip_kh_loop, T_NEAR); + L(kh_loop); { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_iw_start(ki, l_overflow); // 0; + int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) { + + for (int jj = jj_start ; jj < jj_end; jj += stride_w) { + int aux_output_offset + = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2; + vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w), + ptr[aux_reg_ddst + + sizeof(float) * aux_output_offset]); + } + + for (int ii = 0; ii < nb_ic_block; ii++) { + int aux_kernel_offset + = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block + + ki * jcp.ic_block * jcp.oc_block + + ofm2 * jcp.ic_block; + vmovups(ymm15, + ptr[aux_reg_kernel + + sizeof(float) * aux_kernel_offset]); + for (int jj = jj_start; jj < jj_end; jj += stride_w) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15); + } + } + } + add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block + * ic_block); + sub(aux_reg_ddst, sizeof(float) * ow * oc_block); + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + L(skip_kh_loop); + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, + sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_loop, T_NEAR); + L(skip_kd_loop); + + pop(oi_iter); + } + + if (one_of(jcp.ndims, 3, 4)) { + int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow + * jcp.oc_block; + int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw + * jcp.ic * jcp.oc_block; + + add(aux_reg_ddst_oc_loop, ddst_oc_shift); + add(aux_reg_kernel_oc_loop, kernel_oc_shift); + + inc(reg_channel); + cmp(reg_channel, reg_channel_work); + jl(oc_loop, T_NEAR); + + L(skip_oc_loop); + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + } + + Label no_update_label; + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + for (int ii = 0; ii < nb_ic_block; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; + vmovups(Ymm(15), + make_safe_addr(reg_dsrc, offt, reg_long_offt)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + Ymm(15)); + + } + } + L(no_update_label); + + for (int ii = 0; ii < nb_ic_block; ii++) + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; + vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt), + Ymm(ur_w * ii + jj)); + } +} + +void jit_avx2_conv_bwd_data_kernel_f32::generate() { + preamble(); + + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]); + + int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block; + int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block; + + int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); + int r_overflow = nstl::max(0, (jcp.kw - 1 + - nstl::max(0, jcp.r_pad)) / jcp.stride_w); + int r_overflow1 = nstl::max(0, (jcp.kw - 1 + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + + int n_oi = jcp.iw / jcp.ur_w; + if (r_overflow1 > 0) + n_oi--; + + if (jcp.ur_w == jcp.iw) { + compute_loop(jcp.ur_w, l_overflow, r_overflow); + } else if (n_oi == 0) { + compute_loop(jcp.ur_w, l_overflow, r_overflow1); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + if (jcp.ur_w_tail != 0) + compute_loop(jcp.ur_w_tail, 0, r_overflow); + } else { + xor_(oi_iter, oi_iter); + if (l_overflow > 0) { + compute_loop(jcp.ur_w, l_overflow, 0); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + inc(oi_iter); + } + + if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) { + Label ow_loop; + L(ow_loop); { + compute_loop(jcp.ur_w, 0, 0); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + inc(oi_iter); + cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR); + } + } + + if (r_overflow1 > 0 ) { + compute_loop(jcp.ur_w, 0, r_overflow1); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + } + + if (jcp.ur_w_tail != 0) + compute_loop(jcp.ur_w_tail, 0, r_overflow); + } + + this->postamble(); +} + +status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + if (!mayiuse(avx2)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + + int ndims = diff_src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; + jcp.iw = diff_src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + const int simd_w = 8; + + /* derivatives */ + jcp.idp = jcp.id + 2 * jcp.f_pad; + jcp.ihp = jcp.ih + 2 * jcp.t_pad; + jcp.iwp = jcp.iw + 2 * jcp.l_pad; + jcp.ohp = jcp.oh; /* do we really need */ + jcp.owp = jcp.ow; /* padded output ??? */ + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + /* gemm-based convolution performs better in these cases */ + if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1) + return status::unimplemented; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + if (jcp.oc % jcp.oc_block) return status::unimplemented; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.nb_ic_blocking = 1; + jcp.nb_oc_blocking = 1; + jcp.ur_w = 1; + + if(one_of(ndims, 3, 4) && jcp.ow < 40) + jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2; + + if (ndims == 3) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIw8i8o, gOIw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIhw8o8i, gOIhw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nCdhw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIdhw8o8i, gOIdhw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); + } + + bool args_ok = true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i, + gOIdhw8o8i, OIdhw8o8i) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) + && jcp.stride_w == jcp.stride_h + && jcp.stride_d == 1 + && jcp.dilate_d == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0 + && jcp.ic % simd_w == 0 + && jcp.oc % simd_w == 0 + && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1 + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; + if (!args_ok) return status::unimplemented; + jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad; + int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); + + const int max_regs = 15; /* Maximun number of registers available for + result accumulation and delta dst data. + One additional register is reserved for weights + data. */ + + /* Find the best blocking with maximum number of fma instructions + per ur_w * nb_ic_blocking compute loops. Number of required registers + is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs. + ur_w must be divisible by stride_w */ + if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers + distribution exceeds max_regs */ + return status::unimplemented; + + int best_nfmas = 0; + for (int b = 1; b <= 4; b++) + { + if (jcp.nb_ic % b != 0) + continue; + + for (int u = jcp.stride_w; + u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w; + u += jcp.stride_w) + { + int ur_w = nstl::min(u, jcp.iw); + /* maximum 1 step with l_overflow so far */ + if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw) + continue; + int nfmas = utils::div_up(ur_w, jcp.stride_w) * b; + if (nfmas > best_nfmas + || (nfmas == best_nfmas && jcp.ur_w < ur_w)) { + jcp.ur_w = ur_w; + jcp.nb_ic_blocking = b; + best_nfmas = nfmas; + } + } + } + if (best_nfmas == 0) /* can't find appropriate blocking */ + return status::unimplemented; + + jcp.ur_w_tail = jcp.iw % jcp.ur_w; + + int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + /* maximum 1 ur_w block with r_overflow so far */ + if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + + if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) + return status::unimplemented; + + return status::success; +} + +void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +void jit_avx2_conv_bwd_weights_kernel_f32::generate() { + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + compute_oh_loop_common(); + this->postamble(); +} + +status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d) { + if (!mayiuse(avx2)) return status::unimplemented; + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; + jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); + } + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + + const bool flat = jcp.ic == 3; + const bool mimo = !flat; + + const int simd_w = 8; + + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + + int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id + - jcp.f_pad); + if (ndims == 5) + if (jcp.f_pad != 0 || back_pad != 0) + return status::unimplemented; + + const int max_h_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1); + const int max_w_pad = ((jcp.kw - 1) * (jcp.dilate_w + 1) + 1); + const bool boundaries_ok = true + && jcp.t_pad < max_h_pad && jcp.b_pad < max_h_pad + && jcp.l_pad < max_w_pad && jcp.r_pad < max_w_pad; + if (!boundaries_ok) + return status::unimplemented; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + bool args_ok = true + && IMPLICATION(flat, true + && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, + gOdhwi8o)) + && IMPLICATION(mimo, true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, + OIdhw8i8o, gOIdhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) + && IMPLICATION(mimo, jcp.ic % simd_w == 0) + && jcp.oc % simd_w == 0 + && jcp.kw < 14 + && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */ + && jcp.kh <= jcp.ih /* [bwd_w:r2] */ + && jcp.kd <= jcp.f_pad + jcp.id + && jcp.kd <= jcp.id + && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */ + && jcp.dilate_d == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0; + if (!args_ok) return status::unimplemented; + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + + return status::success; +} + +void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() +{ + Label kd_comeback_loop; + mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0 + L(kd_comeback_loop); { + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult); + sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block + * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kd_comeback_loop, T_NEAR); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() +{ + mov(kj, reg_kh); + Label kh_comeback_loop; + L(kh_comeback_loop); { + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + sub(reg_input, sizeof(float) * jcp.iw * inp_mult); + sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_comeback_loop, T_NEAR); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step( + int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset, + int kernel_offset, int output_offset) +{ + const int kw = jcp.kw; + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t off + = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block + + kernel_offset; + vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]); + } + + for (int i_ur = 0; i_ur < ur_w; i_ur++) { + vmovups(Ymm(kw * ic_block_step + 0), + yword[reg_output + + sizeof(float) * i_ur * oc_block + output_offset]); + + for (int i_kw = 0; i_kw < kw; i_kw++) { + int i_iw = i_ur * jcp.stride_w + i_kw; + if (i_iw - pad_l < 0 + || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r) + continue; + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t i_off = (size_t)input_offset + sizeof(float)*( + one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? (i_iw - pad_l) + i_ic + * ((size_t)jcp.id * jcp.ih * jcp.iw) + : (i_iw - pad_l) * ic_block + i_ic); + vbroadcastss(Ymm(kw * ic_block_step + 1), + make_safe_addr(reg_input, i_off, reg_long_offt)); + vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic), + Ymm(kw * ic_block_step + 0), + Ymm(kw * ic_block_step + 1)); + } + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t off + = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block + + kernel_offset; + vmovups(yword[reg_kernel + off], + Ymm(i_kw * ic_block_step + i_ic)); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp() +{ + int ic_block_step; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block; + } else { + ic_block_step = jcp.kw > 7 ? 1 + : jcp.kw > 3 ? 2 + : jcp.kw > 1 ? 4 : 8; + } + + const int max_ur_w = jcp.ow > 56 ? 14 : 28; + + if (jcp.ow <= max_ur_w) + compute_oh_step_unroll_ow(ic_block_step, max_ur_w); + else + compute_oh_step_common(ic_block_step, max_ur_w); + + if (jcp.ndims == 5) { + od_step_comeback_pointers(); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } else { + oh_step_comeback_pointers(); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow( + int ic_block_step, int max_ur_w) +{ + UNUSED(max_ur_w); + + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block; + Label kd_loop; + + const int r_pad + = nstl::max(0, + (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + + if (jcp.ndims == 5) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + mov(ki, jcp.kd); + L(kd_loop); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + Label kh_loop; + L(kh_loop); { + xor_(b_ic, b_ic); + Label ic_block_loop; + L(ic_block_loop); { + compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, + 0, 0); + size_t inp_icblk_stride = sizeof(float) * ic_block_step + * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? jcp.id*jcp.ih*jcp.iw : 1); + safe_add(reg_input, inp_icblk_stride, reg_long_offt); + add(reg_kernel, sizeof(float) * ic_block_step * oc_block); + add(b_ic, ic_block_step); + cmp(b_ic, ic_block); + jl(ic_block_loop, T_NEAR); + } + if(one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, offt, reg_long_offt); + add(reg_input, sizeof(float) * jcp.iw); + } else { + add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); + } + add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_loop, T_NEAR); + } + +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common( + int ic_block_step, int max_ur_w) +{ + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + const int stride_w = jcp.stride_w; + int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block; + Label kd_loop; + + const int r_pad = jcp.r_pad; + + int ur_w = nstl::min(jcp.ow, max_ur_w); + int ur_w_trips = jcp.ow / ur_w; + int ur_w_tail = jcp.ow % ur_w; + if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) { + if (ur_w_trips > 1) { + ur_w_tail += ur_w; + ur_w_trips--; + } else { + ur_w_tail += (ur_w - ur_w / 2); + ur_w = ur_w / 2; + } + } + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_block; + + int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult; + int output_comeback = ur_w_trips * ur_w * oc_block; + + if (jcp.ndims == 5) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + mov(ki, jcp.kd); + L(kd_loop); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + Label kh_loop; + L(kh_loop); { + xor_(b_ic, b_ic); + Label ic_block_loop; + L(ic_block_loop); { + if (jcp.l_pad != 0) { + ur_w_trips--; + compute_ic_block_step(ur_w, + jcp.l_pad, 0, ic_block_step, 0, 0, 0); + add(reg_input, sizeof(float) + * (ur_w * stride_w - jcp.l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_block); + } + + if (ur_w_trips > 0) { + xor_(reg_ur_w_trips, reg_ur_w_trips); + Label ow_block_loop; + L(ow_block_loop); { + compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); + add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_block); + + inc(reg_ur_w_trips); + cmp(reg_ur_w_trips, ur_w_trips); + jl(ow_block_loop, T_NEAR); + } + } + + if (ur_w_tail > 0) + compute_ic_block_step(ur_w_tail, + 0, r_pad, ic_block_step, 0, 0, 0); + + sub(reg_input, sizeof(float) * input_comeback); + sub(reg_output, sizeof(float) * output_comeback); + + size_t inp_icblk_stride = sizeof(float) * ic_block_step + * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? jcp.id*jcp.ih*jcp.iw : 1); + safe_add(reg_input, inp_icblk_stride, reg_long_offt); + add(reg_kernel, sizeof(float) * ic_block_step * oc_block); + + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_loop, T_NEAR); + } + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, offt, reg_long_offt); + add(reg_input, sizeof(float) * jcp.iw); + } else { + add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); + } + add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_loop, T_NEAR); + } + +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common() +{ + const int icoc_block = jcp.ic_block * jcp.oc_block; + const int t_pad = jcp.t_pad; + const int stride_h = jcp.stride_h; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + int b_pad = jcp.b_pad; + + Label oh_tpad_loop, oh_loop, oh_loop_end; + + mov(reg_kh, jcp.kh); + xor_(reg_ih_count, reg_ih_count); + xor_(reg_oj, reg_oj); + if (t_pad > 0) { + assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */ + mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih); + add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block); + + L(oh_tpad_loop); { + compute_oh_step_disp(); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + add(reg_kh, stride_h); + + /* the overlap between input and kernel may not reach kernel size. + * so far we do not support that (until we put constant here) */ + const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */ + cmp(reg_kh, final_inp_ker_overlap); + jl(oh_tpad_loop, T_NEAR); + } + + if (t_pad % stride_h != 0) { + int inp_corr = stride_h - t_pad % stride_h; + add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block); + add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult); + } + } + cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); + jge(oh_loop_end, T_NEAR); + cmp(reg_oj, jcp.oh); + jge(oh_loop, T_NEAR); + + mov(reg_kh, jcp.kh); + L(oh_loop); { + compute_oh_step_disp(); + add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); + jge(oh_loop_end, T_NEAR); + + cmp(reg_oj, jcp.oh); + jl(oh_loop, T_NEAR); + } + L(oh_loop_end); + if (b_pad > 0) { + Label oh_bpad_loop, oh_bpad_loop_end; + cmp(reg_oj, jcp.oh); + jge(oh_bpad_loop_end, T_NEAR); + + mov(reg_kh, jcp.ih + t_pad); + sub(reg_kh, reg_ih_count); + L(oh_bpad_loop); { + compute_oh_step_disp(); + add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + + sub(reg_kh, stride_h); + cmp(reg_kh, 0); + jle(oh_bpad_loop_end, T_NEAR); + + inc(reg_oj); + cmp(reg_oj, jcp.oh); + jl(oh_bpad_loop, T_NEAR); + } + L(oh_bpad_loop_end); + } +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp new file mode 100644 index 0000000000..412c50c9ee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX2_CONV_KERNEL_F32_HPP +#define JIT_AVX2_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_conv_fwd_kernel_f32: public jit_generator { + jit_avx2_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_avx2_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32) + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t aux_reg_input = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r9; + reg64_t reg_output = rsi; + reg64_t reg_bias = rbx; + + reg64_t aux_reg_inp_d = r11; + reg64_t aux_reg_ker_d = abi_not_param1; + + reg64_t reg_ki = rsi; + reg64_t kj = r10; + reg64_t oi_iter = r11; + reg64_t ki_iter = r12; + reg64_t reg_kh = abi_not_param1; + reg64_t reg_oc_blocks = r14; + reg64_t imm_addr64 = r15; + reg64_t reg_long_offt = r15; + Xbyak::Reg32 reg_ci_flag = r13d; + + Xbyak::Ymm ytmp = Xbyak::Ymm(14); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, + int oc_blocks); + inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, + char pad_label, int oc_blocks, char oc_blocks_label); + inline void width_blk_step(int ur_w, int pad_l, int pad_r, + char pad_label, int oc_blocks, char oc_blocks_label); + inline void solve_common(int oc_blocks, char oc_blocks_label); + + void generate(); +}; + +struct jit_avx2_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32) + + jit_avx2_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t reg_dsrc = rsi; + reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only + reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5 + case only */ + + reg64_t aux_reg_dst_d = r12; // used in ndims == 5 case only + reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only + + reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only + reg64_t kj = r11; + reg64_t oi_iter = r12; + reg64_t reg_kh = r14; + reg64_t reg_channel = r13; // used in ndims < 5 case only + reg64_t reg_channel_work = r9; // used in ndims < 5 case only + reg64_t reg_long_offt = r15; + + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); + + void generate(); + + inline int get_iw_start(int ki, int l_overflow) + { + int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return res; + } + + inline int get_iw_end(int ur_w, int ki, int r_overflow) + { + if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return ur_w - res; + } +}; + +struct jit_avx2_conv_bwd_weights_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32) + + jit_avx2_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t reg_kernel = rdx; + reg64_t reg_output = rsi; + reg64_t b_ic = abi_not_param1; + reg64_t kj = r8; + reg64_t reg_kh = r9; + reg64_t reg_ur_w_trips = r10; + reg64_t reg_tmp = r11; + reg64_t reg_oj = r15; + reg64_t reg_ih_count = rbx; + reg64_t aux_reg_input = r12; + reg64_t aux_reg_kernel = r13; + reg64_t ki = r14; + reg64_t reg_long_offt = r11; + + inline void od_step_comeback_pointers(); + inline void oh_step_comeback_pointers(); + inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset); + inline void compute_oh_step_disp(); + inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); + inline void compute_oh_step_common(int ic_block_step, int max_ur_w); + inline void compute_oh_loop_common(); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp new file mode 100644 index 0000000000..13f61e84fe --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp @@ -0,0 +1,410 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx2_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define src_blk_off(f, n, c, d, h, w) \ + (pd()->ndims() == 3) \ + ? (f).blk_off(n, c, w) \ + : (pd()->ndims() == 4) \ + ? (f).blk_off(n, c, h, w) \ + : (f).blk_off(n, c, d, h, w) + +#define wht_blk_off_(f, g, ...) \ + pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__) +#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \ + (pd()->ndims() == 3) \ + ? wht_blk_off_(f, g, oc, ic, kw) \ + : (pd()->ndims() == 4) \ + ? wht_blk_off_(f, g, oc, ic, kh, kw) \ + : wht_blk_off_(f, g, oc, ic, kd, kh, kw) + +void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); + const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od + * jcp.oh; + + auto ker = [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + od, jcp.od, oh, jcp.oh); + for (size_t iwork = start; iwork < end; ++iwork) { + int ocb = ocbb * jcp.nb_oc_blocking; + int ocb_num = jcp.nb_oc_blocking; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + auto par_conv = jit_conv_call_s(); + + const int ij = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); + const int i_b_overflow = nstl::max(jcp.ih, ij + + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; + + const int dj = od * jcp.stride_d; + const int d_t_overflow = nstl::max(0, jcp.f_pad - dj); + const int d_b_overflow = nstl::max(jcp.id, dj + + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id; + + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb; + + const int ih = nstl::max(ij - jcp.t_pad + + div_up(i_t_overflow, + (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); + + const int id = nstl::max(dj - jcp.f_pad + + div_up(d_t_overflow, + (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0); + + par_conv.src = &src[src_blk_off(src_d, n, + jcp.ic == 3 ? 0 : _ic, id, ih, 0)]; + + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)]; + + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1)); + par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, + jcp.ic == 3 ? 0 : icb, wd, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = + &bias[bias_d.blk_off(_oc * jcp.oc_block)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_blocks = + nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; + + par_conv.kw_padding = 0; + const int kh_padding = jcp.kh + - div_up(i_t_overflow, (jcp.dilate_h + 1)) + - div_up(i_b_overflow, (jcp.dilate_h + 1)); + par_conv.kh_padding = nstl::max(0, kh_padding); + + const int kd_padding = jcp.kd + - div_up(d_t_overflow, (jcp.dilate_d + 1)) + - div_up(d_b_overflow, (jcp.dilate_d + 1)); + par_conv.kd_padding = nstl::max(0, kd_padding); + + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + od, jcp.od, oh, jcp.oh); + } + icbb += icb_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad(ctx).get(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +void jit_avx2_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + int icb_work = jcp.nb_ic / jcp.nb_ic_blocking; + int ih_block_size = jcp.ih; + int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); + size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks; + if (work_amount < (size_t)2 * mkldnn_get_max_threads()) { + ih_block_size = 1; + num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); + work_amount *= num_ih_blocks; + } + + auto ker = [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + size_t n{0}, g{0}, icbb{0}, ihb{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, + ihb, num_ih_blocks); + for (size_t iwork = start; iwork < end; ++iwork) { + for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking) + for (int id = 0; id < jcp.id; ++id) { + auto par_conv = jit_conv_call_s(); + + const int idp = jcp.id + 2 * jcp.f_pad; + const int d_t_overflow = nstl::max(0, + jcp.kd - 1 - id - jcp.f_pad); + const int back_pad = idp - jcp.id - jcp.f_pad; + const int d_b_overflow = nstl::max(0, + jcp.kd - 1 - (jcp.id - 1 - id) - back_pad); + const int od = id + jcp.f_pad - d_b_overflow; + + int ih_start = ihb * ih_block_size; + int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size); + for (int ih = ih_start; ih < ih_end; ++ih) { + + const int i_t_overflow = nstl::max(0, (jcp.kh - 1 + - ih - jcp.t_pad) / jcp.stride_h); + const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih + + ih - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ih) % jcp.stride_h); + int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h; + + par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow; + par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow - i_b_overflow; + par_conv.kw_padding = 0; + + const int k_lo = overflow_kh_lo + + i_b_overflow * jcp.stride_h; + const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h; + + par_conv.src = &diff_src[src_blk_off(diff_src_d, n, + /*jcp.ic == 3 ? 0 :*/ + g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)]; + par_conv.dst = &diff_dst[src_blk_off(diff_dst_d, + n, g * jcp.nb_oc + oc, od, oh, 0)]; + par_conv.filt = &weights[wht_blk_off(weights_d, g, oc, + jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb, + d_b_overflow, k_lo, 0)]; + + par_conv.src_prf = nullptr; + par_conv.dst_prf = nullptr; + par_conv.filt_prf = nullptr; + par_conv.channel = oc; + par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc, + jcp.nb_oc_blocking); + + kernel_->jit_ker(&par_conv); + } + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb, + num_ih_blocks); + } + }; + + parallel(0, ker); +} + +void jit_avx2_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto scratchpad = this->scratchpad(ctx); + + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_wei); + auto rw = this->reducer_weights_; + rw->init(reducer_wei_scratchpad); + + auto ker = [&](int ithr, int nthr) { + assert(nthr == rw->balancer().nthr_); + + const int w_job_start = rw->balancer().ithr_job_off(ithr); + const int w_njobs = rw->balancer().ithr_njobs(ithr); + + if (w_njobs == 0) return; + + /* reduction dimension */ + int img_od_start{0}, img_od_end{0}, img{0}, od_s{0}; + balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_, + rw->balancer().id_in_group(ithr), img_od_start, img_od_end); + + int img_start = img_od_start, img_end = img_od_end; + nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); + const int img_first = img; + + /* jobs */ + int g_start{0}, ocb_start{0}, icb_start{0}; + nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start, + jcp.nb_oc, icb_start, jcp.nb_ic); + + while (img_start < img_end) { + int g = g_start, ocb = ocb_start, icb = icb_start; + + const int work_rem = img_end - img_start; + const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem; + const int id_s = od_s * jcp.stride_d; + const int idp = jcp.id + jcp.f_pad + jcp.back_pad; + + if (id_s < idp - jcp.back_pad - jcp.kd + 1) + for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic + icb; + + /* TODO: put dw <-- 0 in kernel */ + if (img == img_first) + array_set(rw->get_local_ptr(ithr, diff_weights, + reducer_wei_scratchpad) + + w_job_loc * rw->balancer().job_size_, 0, + rw->balancer().job_size_); + + for (int od = od_s; od < od_e; ++od) { + const int id = od * jcp.stride_d; + if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break; + + auto par_conv = jit_conv_call_s(); + par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)]; + par_conv.dst = + &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)]; + par_conv.filt = rw->get_local_ptr(ithr, diff_weights, + reducer_wei_scratchpad) + + w_job_loc * rw->balancer().job_size_; + + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb, + jcp.nb_ic); + } + nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); + } + rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, + jcp.nb_oc); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 8; ++o) + d_bias[o] = 0.; + + for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 8; ++o) + d_bias[o] += d_dst[o]; + d_dst += 8; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(0, [&](const int ithr, const int nthr) { + ker(ithr, nthr); + if (pd()->with_bias()) + ker_bias(ithr, nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp new file mode 100644 index 0000000000..bb65bce79c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP +#define CPU_JIT_AVX2_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx2_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_, + *desc(), src_md(), weights_md(), dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const bool flat = IC() < 8; + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + }; + + jit_avx2_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } + ~jit_avx2_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_fwd_kernel_f32 *kernel_; +}; + +struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf( + jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i) + : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_avx2_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~jit_avx2_convolution_bwd_data_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_bwd_data_kernel_f32 *kernel_; +}; + +struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf( + jcp_, *desc(), *src_md(), *diff_weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_wei); + reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); + + return status::success; + } + + jit_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + cpu_reducer_t::conf_t reducer_wei_conf_; + + protected: + bool set_default_formats() { + using namespace format_tag; + const bool flat = IC() == 3; + + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + + private: + void init_balancers() { + const int max_threads = mkldnn_get_max_threads(); + const size_t max_buffer_size = 1<<21; /* just a heuristic */ + + if(with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(max_threads, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, + max_buffer_size)); + } + + reducer_wei_conf_.init(reduce_balancer_t(max_threads, + jcp_.kd * jcp_.kh * jcp_.kw + * jcp_.ic_block * jcp_.oc_block, + jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc, + jcp_.mb * jcp_.od, max_buffer_size)); + } + }; + + jit_avx2_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , reducer_weights_(nullptr) + , reducer_bias_(nullptr) + { + kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_); + reducer_bias_ = + new cpu_reducer_t(pd()->reducer_bia_conf_); + reducer_weights_ = + new cpu_reducer_t(pd()->reducer_wei_conf_); + } + + ~jit_avx2_convolution_bwd_weights_t() { + delete kernel_; + delete reducer_weights_; + delete reducer_bias_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_bwd_weights_kernel_f32 *kernel_; + cpu_reducer_t *reducer_weights_, *reducer_bias_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp new file mode 100644 index 0000000000..635b83b2bf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp @@ -0,0 +1,1255 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" +#include "cpu_barrier.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_common_1x1_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt)); + + if (jcp.ver == ver_4fma) + { + Label bcast_loop; + Label bcast_loop_wraparound; + Label bcast_loop_out; + Label bcast_loop_ur_full; + + cmp(bcast_loop_iter, jcp.ur); + jle(bcast_loop_wraparound, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jg(bcast_loop, T_NEAR); + } + + L(bcast_loop_wraparound); + if (jcp.ur_tail) { + je(bcast_loop_ur_full, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + jmp(bcast_loop_out, T_NEAR); + } + L(bcast_loop_ur_full); + reduce_loop(load_loop_blk, jcp.ur, 0, true); + L(bcast_loop_out); + } + else + { + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + L(bcast_loop_tail_out); + } + } +} + +void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk, + int ur, int substep, bool wraparound) +{ + auto vreg_load = [=](int i_load, int i_fma) { + return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step) + + jcp.fma_step * i_load + i_fma); + }; + + auto vreg_accum = [=](int i_load, int i_ur) { + return Zmm(i_ur * load_loop_blk + i_load); + }; + + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_bias_data, + jcp.typesize_out * jcp.oc_block * i_load); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + int offt; + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + offt = (i_reduce == jcp.reduce_loop_unroll) + ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll + : i_ur * jcp.reduce_loop_unroll + i_reduce; + } else { + if (jcp.transpose_src) { + const int reduce_group = i_reduce / 4; + const int reduce_shift = i_reduce % 4; + offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift; + } + else + offt = i_reduce * jcp.ic_block + i_ur; + } + return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, + bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load) { + int offt; + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; + return EVEX_compress_addr(aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + + jcp.typesize_in * offt); + }; + + auto output_ptr = [=](int i_load, int i_ur) { + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) + return EVEX_compress_addr(aux_reg_output_data, + (i_load * jcp.bcast_dim + i_ur) * jcp.load_block + * jcp.typesize_out); + else + return ptr[aux_reg_output_data + + (i_load + ? reg_output_stride * i_load + : 0) // TODO: Xbyak should allow 0 scale + + jcp.typesize_out * jcp.load_block * i_ur]; + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_sum) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + for (int i_ur = 0; i_ur < ur; ++i_ur) { + mic_prefetcht1(output_ptr(i_load, i_ur)); + } + } + } + + if (jcp.with_bias + && one_of(jcp.prop_kind, forward_training, forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero, T_NEAR); + + for (int i_load = 0; i_load < load_loop_blk; i_load++) + for (int i_ur = 0; i_ur < ur; ++i_ur) + vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load)); + jmp(init_done, T_NEAR); + } + + L(init_zero); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vpxord(r, r, r); + } + L(init_done); + }; + + auto store = [=]() { + Label store_noadd; + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + auto r = vreg_accum(i_load, i_ur); + vaddps(r, r, output_ptr(i_load, i_ur)); + } + + L(store_noadd); + if (jcp.with_eltwise) { + Label store_noeltwise; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_noeltwise, T_NEAR); + + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + L(store_noeltwise); + } + + auto store_output = [=](bool output_is_aligned) { + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + if (output_is_aligned && jcp.use_vmovntps) + vmovntps(output_ptr(i_load, i_ur), + vreg_accum(i_load, i_ur)); + else + vmovups(output_ptr(i_load, i_ur), + vreg_accum(i_load, i_ur)); + }; + + Label unaligned_store, end_store; + test(aux_reg_output_data, cpu_isa_traits::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + }; + + auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load, + bool last_block, bool wraparound, int reduce_step) + { + bool pf_ker_l1 = true; + bool pf_ker_l2 = wraparound; + int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk; + int i_op = (i_reduce / reduce_step) * ur * load_loop_blk + + i_ur * load_loop_blk + i_load; + + int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0; + int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0; + int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur; + + int pf_inp_ops = n_ops / 2; // # of operations during which to pf input + int pf_inp_trigger; + if (jcp.prop_kind == backward_weights) + pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block); + else + pf_inp_trigger = nstl::max(1, pf_inp_ops / ur); + + int n_other_pf = + load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1); + int n_other_pf_ops = n_ops - pf_inp_ops; + int other_pf_trigger + = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0; + + if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) { + // input prefetches have the highest priority b/c the + // first iteration of the kernel block touches all the + // cache lines + int i_pf = i_op / pf_inp_trigger; + auto pf_reg = wraparound && last_block + ? reg_bcast_data + : (last_block ? aux1_reg_bcast_data + : aux_reg_bcast_data); + int offt = i_pf; + if (jcp.prop_kind == backward_weights) { + offt += wraparound && last_block + ? 0 + : (last_block ? jcp.is : jcp.reduce_block); + offt *= jcp.bcast_block; + } else { + offt += wraparound && last_block + ? 0 + : (last_block ? jcp.ur : jcp.bcast_dim); + offt *= jcp.reduce_block; + } + mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); + } else if (i_op >= pf_inp_ops && n_other_pf) { + // remaining prefetches are spread among the rest of the + // operations; prefetches for output take priority + // TODO: spread L2 prefetches among L1 prefetches + i_op -= pf_inp_ops; + if (i_op % other_pf_trigger == 0) { + int i_pf = i_op / (load_loop_blk * other_pf_trigger); + if (i_pf < n_pf_ker_l2) { + int offt = (i_pf + (i_load + 1) * jcp.reduce_dim) + * jcp.load_block; + mic_prefetcht1(ptr[aux_reg_load_data + + offt * jcp.typesize_in]); + } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) { + i_pf -= n_pf_ker_l2; + auto pf_reg = last_block ? reg_load_data + : aux_reg_load_data; + int offt = (i_pf + i_load * jcp.reduce_dim + + (last_block + ? (wraparound ? jcp.reduce_dim : 0) + : jcp.reduce_block)) + * jcp.load_block; + mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); + } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) { + i_pf -= n_pf_ker_l1 + n_pf_ker_l2; + int offt = i_pf * jcp.load_block; + mic_prefetcht0(ptr[aux_reg_output_data + + offt * jcp.typesize_out]); + } + } + } + }; + + auto fma_block = [=](bool last_block) { + assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); + + int reduce_step = jcp.fma_step; + + for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll; + i_reduce += reduce_step) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + // if transposed input data used and if spatial size is + // not divided by transpose step (4) then for last reduce step + // we should load only needed load_registers data + // and clear remaining + if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block + && i_reduce == jcp.reduce_loop_unroll - reduce_step) { + Label load_all; + Label load_finish; + test(reg_reduce_pos_flag, FLAG_SP_LAST); + jz(load_all, T_NEAR); + + const int n_loads = jcp.is % jcp.fma_step; + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + if (i_fma < n_loads) + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + else + vpxord(vreg_load(i_load, i_fma), + vreg_load(i_load, i_fma), + vreg_load(i_load, i_fma)); + } + jmp(load_finish); + + L(load_all); + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + } + L(load_finish); + } else { + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + } + } + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (jcp.ver == ver_avx512_core && jcp.expl_bcast + && load_loop_blk > 1) + vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false)); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + if (jcp.ver == ver_4fma) + v4fmaddps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), + bcast_ptr(i_reduce, i_ur, false)); + else if (jcp.ver == ver_avx512_core && jcp.expl_bcast + && load_loop_blk > 1) + vfmadd231ps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), vreg_bcast); + else + vfmadd231ps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), + bcast_ptr(i_reduce, i_ur, true)); + prefetch_callback(ur, i_reduce, i_ur, i_load, + last_block, wraparound, reduce_step); + } + } + } + }; + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx512_common_1x1_conv_kernel::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + + sub(rsp, stack_space_needed); + + if (jcp.with_bias) + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (one_of(jcp.prop_kind, forward_training, forward_inference)) + mov(reg_relu_ns, reinterpret_cast(&jcp.eltwise.alpha)); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto load_loop_body = [=](int load_loop_blk) { + bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_out); + add(reg_output_data, + load_loop_blk * jcp.bcast_dim * jcp.load_block * + jcp.typesize_out); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.bcast_dim * jcp.load_block * + jcp.typesize_out); + break; + case backward_weights: + for (int i_load = 0; i_load < load_loop_blk; i_load++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + const int simd_w = 16; + + Label load_loop_blk[7]; + + static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 }; + static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; + static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 }; + + const int size_ur_cases_fma + = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? + sizeof(ur_cases_fma_expl_bcast) : + sizeof(ur_cases_fma_embd_bcast); + const int size_ur_cases_4fma = sizeof(ur_cases_4fma); + + const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? + ur_cases_fma_expl_bcast : + ur_cases_fma_embd_bcast; + const int *ur_cases = jcp.ver == ver_4fma ? ur_cases_4fma : ur_cases_fma; + const int num_ur_cases = + (jcp.ver == ver_4fma ? size_ur_cases_4fma : size_ur_cases_fma) + / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.ur <= ur_cases[ur_idx]) { + cmp(reg_load_loop_work, simd_w * (label_idx + 1)); + jle(load_loop_blk[label_idx], T_NEAR); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + if (jcp.ur <= ur_cases[ur_idx]) { + int label_idx = num_ur_cases - ur_idx - 1; + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + je(load_loop_blk[num_ur_cases], T_NEAR); + } + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp(reg_load_loop_work, 2 * label_idx * simd_w); + je(load_loop_blk[label_idx - 1], T_NEAR); + } + cmp(reg_load_loop_work, (label_idx + 1) * simd_w); + jge(load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx > 0; --idx) { + cmp(reg_load_loop_work, simd_w * (idx + 1)); + je(load_loop_blk[idx], T_NEAR); + } + if (ur_idx < num_ur_cases - 2) { + cmp(reg_load_loop_work, simd_w); + jle(load_loop_blk[0], T_NEAR); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_common_1x1_conv_kernel::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + if (!mayiuse(avx512_common)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, + format_kind::undef, cd.diff_bias_desc.format_kind) + != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + jcp.tr_is = rnd_up(jcp.is, 4); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (dst_d.data_type() == data_type::s32) return status::unimplemented; + } + + auto dat_tag = pick(ndims - 3, nCw16c, nChw16c); + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + jcp.transpose_src = false; + + if (everyone_is(data_type::f32, src_d.data_type(), + weights_d.data_type(), dst_d.data_type())) + { + const int is_bwd_d = jcp.prop_kind == backward_data; + format_tag_t wei_tag = with_groups + ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i, + gOIhw16i16o, gIOhw16o16i) + : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i, + OIhw16i16o, IOhw16o16i); + + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) && + ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4 + == 0) { + jcp.ver = ver_4fma; + jcp.fma_step = 4; + } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops) + && !reduce_src + /* Heuristic condition for relation of src size to oc. Otherwise + the src transposition overhead exceed the benefit from 4fma + */ + && ((jcp.is * jcp.ic) / jcp.oc <= 2048) + && mkldnn_thr_syncable() + ) + { + jcp.transpose_src = true; + jcp.ver = ver_4fma; + jcp.fma_step = 4; + } else { + jcp.ver = (mayiuse(avx512_core)) ? ver_avx512_core : ver_fma; + jcp.fma_step = 1; + } + jcp.typesize_in = sizeof(prec_traits::type); + jcp.typesize_out = sizeof(prec_traits::type); + } else { + return status::unimplemented; + } + + /* once all the formats are set, check the padding consistency */ + args_ok = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + const int SMALL_SPATIAL = 10; + const int BIG_SPATIAL = 28; + const int BIG_REDUCE_DIM = 1024; + const int BIG_LOAD_DIM = 256; + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + int reduce_blocking_max{ 0 }; + + jcp.load_grp_count = 1; + + const int L1_capacity = get_cache_size(1, true) / sizeof(float); + const int L2_size = get_cache_size(2, true) / sizeof(float); + const int L2_capacity = (L2_size * 3) / 4; + + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + } else { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.ic_block; + + jcp.bcast_dim = jcp.os; + } + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; + + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + jcp.load_loop_load_step + = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; + + // adjusting registry blocking + int max_regs, min_regs, size_treshold, ur_step; + const int spatial + = (one_of(jcp.prop_kind, forward_training, forward_inference)) ? + jcp.oh : + jcp.ih; + if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) { + max_regs = 9; + min_regs = 6; + size_treshold = 14; + ur_step = 1; + jcp.expl_bcast = true; + + if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM + && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) { + max_regs = 6; + min_regs = 5; + } + } else { + max_regs = jcp.ver == ver_4fma ? 28 : 30; + min_regs = 9; + size_treshold = jcp.ver == ver_4fma ? 28 : 14; + ur_step = jcp.ver == ver_4fma ? 4 : 1; + jcp.expl_bcast = false; + jcp.use_vmovntps = true; + } + jcp.ur = 1; + for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) { + if ((spatial >= size_treshold && spatial % ur_w == 0) + || (spatial < size_treshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + if (jcp.ur == 1) { + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i -= ur_step) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) + break; + } + } + } + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; + + jcp.bcast_block = jcp.ur; + + jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out; + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_iter_step = jcp.load_block; + + if (jcp.prop_kind == backward_data) + jcp.loop_order = loop_lbr; + else + jcp.loop_order = reduce_src ? loop_blr : loop_lbr; + + int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + int nb_load = div_up(jcp.load_dim, jcp.load_block); + + if (jcp.ver == ver_avx512_core && jcp.expl_bcast) { + if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL + && spatial < BIG_SPATIAL) + reduce_blocking = nstl::min(jcp.reduce_dim, 80); + else if (spatial > SMALL_SPATIAL) + reduce_blocking = nstl::min(jcp.reduce_dim, 512); + else + reduce_blocking = nstl::min(jcp.reduce_dim, 256); + + if ((jcp.mb > 28 && spatial >= 28) + || (jcp.mb > 112 && spatial >= 17)) + jcp.use_vmovntps = true; + else + jcp.use_vmovntps = false; + } else { + + reduce_blocking = nb_reduce; + if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 16; + else if (spatial > SMALL_SPATIAL + && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 8; + reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + } + + // Check input data cache aliasing. + // For other ISA constants may be updated. + // 64 * 1024 is chosen due to 1MB L2 16-way cache. + // 7 is empirical value. It is about half of 16. + // So we leave about half of the set for other data - weights, dst + int way_size = (64 * 1024) / jcp.typesize_in; + int max_hits = 7; + if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) { + int nrb = reduce_blocking / simd_w; + int sp = jcp.bcast_dim; + int wl = way_size / simd_w; + for (int start_off = 0; start_off < jcp.ur; start_off++) { + for (int off = start_off, hits = 0; off < sp * nrb; off += wl) { + if (off % sp >= jcp.ur || ++hits < max_hits) + continue; + int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp); + reduce_blocking + = nstl::min(reduce_blocking, max_r_blocking); + break; + } + } + } + + if (reduce_blocking < jcp.reduce_dim) { + jcp.use_vmovntps = false; + if (jcp.prop_kind == backward_data) + jcp.loop_order = reduce_src ? loop_lbr : loop_rlb; + else + jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; + } + load_blocking = jcp.load_dim; + + int load_size = jcp.load_dim * jcp.reduce_dim; + int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; + + if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads + && nb_load * nb_bcast > nthreads) { + // Some heuristic here + float calc_koef = 0.01, best_cost = FLT_MAX; + int n_lgc = nthreads; + float ratio = (float)load_size / (float)bcast_size; + int best_lgc = ratio > 1 ? n_lgc : 1; + auto calc_job_cost = [&](int lb, int tg, float mem_k) { + int bb_size = jcp.mb * div_up(nb_bcast, tg); + float calc_size = (float)(bb_size * jcp.ur) + * (lb * jcp.load_block) * jcp.reduce_dim; + float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block) + * jcp.reduce_dim; + return calc_koef * calc_size + mem_k * mem_size; + }; + for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) { + lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1; + int min_lb = nb_load / lgc; + int max_lb = div_up(nb_load, lgc); + int min_tg = nthreads / lgc; + int max_tg = div_up(nthreads, lgc); + // Some heuristic here + float mem_koef = (max_tg == 1) ? 1.f : 1.3f; + float job_cost = 0.; + if (nthreads % lgc < nb_load % lgc) { + job_cost = calc_job_cost(max_lb, min_tg, mem_koef); + } else { + auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef); + auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef); + job_cost = nstl::max(job_cost1, job_cost2); + } + + if (job_cost < best_cost) { + best_lgc = lgc; + best_cost = job_cost; + } + } + jcp.load_grp_count = best_lgc; + load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } else { + jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider( + nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); + } + + if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64 + && load_size >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); + load_blocking = jcp.load_block; + } + + if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim + && jcp.oh * jcp.ow > 64 + && IMPLICATION(reduce_src, jcp.load_dim < 1024)) { + /* Looking for best loading dimension blocking + * to get the best thread and data read/write efficiency + * by finding the optimal 'load_chunk' value + * Example: + * for 72 threads and convolution with mb=1, ih=iw=7, oc = 512 + * the 'best' load_chunk value should be 1 + * TODO: remove heuristic constants in above condition + * TODO: check this blocking for other ISA + */ + float best_eff = -1.f; + int best_lgc = 1; + + for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) { + int lgc = div_up(nb_load, load_chunk); + if (lgc > nthreads) + continue; + int thr_per_grp = div_up(nthreads, lgc); + int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp) + * jcp.bcast_block; + int load_per_thr = load_chunk * simd_w; + float data_norm = (bcast_per_thr + load_per_thr) / 2.f; + float data_eff = (bcast_per_thr * load_per_thr) + / (data_norm * data_norm); + float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc) + / div_up(nthreads, lgc); + float thr_eff_in_grp = ((float)jcp.mb * nb_bcast) + / rnd_up(jcp.mb * nb_bcast, thr_per_grp); + float thr_eff = thr_eff_over_grp * thr_eff_in_grp; + float load_eff = (float)nb_load / rnd_up(nb_load, lgc); + float overall_eff = data_eff + thr_eff + load_eff; + if (overall_eff > best_eff) { + best_eff = overall_eff; + best_lgc = lgc; + } + } + jcp.load_grp_count = best_lgc; + load_blocking + = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(nthreads, jcp.load_grp_count)) + * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast + = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking + - jcp.ur * reduce_blocking - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) + space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + } else if (jcp.prop_kind == backward_weights) { + + jcp.use_vmovntps = false; + if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma) + jcp.use_vmovntps = true; + + if (jcp.transpose_src) + jcp.reduce_dim = jcp.tr_is; + else + jcp.reduce_dim = jcp.is; + + if (jcp.ver == ver_4fma) { + // reduce_block should be divided by fma_step + jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4); + } else { + jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true); + if (jcp.reduce_dim % jcp.reduce_block != 0) + jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false); + if (jcp.reduce_block > 256) { + jcp.reduce_block = 1; + } + + } + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) { + // if reduce_block is big then generated JIT code may be big + // for small values of ur because reduce_loop_unroll = reduce_block + jcp.ur = jcp.bcast_block / 2; + jcp.expl_bcast = true; + } else { + jcp.ur = jcp.bcast_block; + jcp.expl_bcast = false; + } + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in; + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in; + + jcp.bcast_loop_output_step = + jcp.oc_block * jcp.ic_block * jcp.typesize_out; + jcp.bcast_loop_output_substep = + jcp.oc_block * jcp.ur * jcp.typesize_out; + jcp.bcast_loop_bcast_step = + jcp.ic_block * jcp.reduce_dim * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in; + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in; + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + balance(jcp, nthreads); + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + load_blocking = best_divider(load_blocking, 16, load_blocking, false); + load_blocking *= jcp.load_block; + + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + int min_bcast_blocking = 5; + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + bcast_blocking = best_divider( + bcast_blocking, min_bcast_blocking, max_bcast_blocking, false); + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + // for reduction balance + if (jcp.ver == ver_avx512_core) { + int max_reduce_blocking + = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim); + int min_reduce_blocking = nstl::min( + L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih)); + reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking, + max_reduce_blocking, true); + reduce_blocking + = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block), + jcp.reduce_block); + } else { + int max_reduce_blocking = L2_capacity + / ((bcast_blocking + load_blocking) * jcp.reduce_block); + max_reduce_blocking = nstl::min(max_reduce_blocking, + (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block); + + int num_jobs = div_up(jcp.load_dim, load_blocking) + * div_up(jcp.bcast_dim, bcast_blocking); + int threads_per_job = nstl::max(1, nthreads / num_jobs); + reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block); + reduce_blocking = div_up(reduce_blocking, threads_per_job); + + reduce_blocking = best_divider(reduce_blocking, + max_reduce_blocking - 2, max_reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + } + + reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block); + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + assert(reduce_blocking_max); + assert(load_blocking % jcp.load_block == 0); + assert(reduce_blocking % jcp.reduce_block == 0); + assert(load_blocking_max % jcp.load_block == 0); + assert(reduce_blocking_max % jcp.reduce_block == 0); + if (jcp.ver == ver_4fma) { + assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + } + + assert(jcp.bcast_block % jcp.ur == 0); + assert(jcp.reduce_dim % jcp.reduce_block == 0); + + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx512_common_1x1_conv_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.with_bias + && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); + + if (jcp.prop_kind == backward_weights) { + const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic; + scratchpad.book(key_conv_wei_reduction, + jcp.typesize_out * wei_size * (jcp.nthr_mb - 1)); + } + + if (jcp.transpose_src) { + const size_t tr_src_size = + (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is; + scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size); + scratchpad.book(key_conv_tr_src_bctx, + sizeof(simple_barrier::ctx_t) * jcp.nthr); + } +} + +void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp, + int nthreads) +{ + // initialize jcp reduction threading properties + jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1; + if (nthreads < jcp.ngroups) { + /* simplification... fortunately it doesn't hurt much */ + return; + } + const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + const int nb_load = div_up(jcp.load_dim, jcp.load_block); + const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + jcp.nthr_g = jcp.ngroups; + const int nthr = nthreads / jcp.nthr_g; + + auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + /* calculate per thread memory cost (read/write). high level + * optimizer tries to minimize memory consumption. few notes: (n1) + * unclear why, but that essentially helps first convolution... + * (n2) assuming the reduction over minibatch is always there: + * - instead of 8 it should be 5 here (write ~= 2 read): + * kernel: temporal workspace 1 write + * reduction: 1 read from workspace and 1 write to the diff_wei + * - but experiments showed 8 works better than 5 or 6... */ + int bcast_koeff = 1; + int load_koeff = 1; + int output_koeff = 12; + if (jcp.transpose_src) { + bcast_koeff = 5; + load_koeff = 1; + output_koeff = 8; + } + return 0 + + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) + * div_up(jcp.ngroups, jcp.nthr_g) + * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block + / jcp.stride_h / jcp.stride_w /* (n1) */ + + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) + * div_up(jcp.ngroups, jcp.nthr_g) + * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block + + (size_t)output_koeff /* (n2) */ + * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) + * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block + * jcp.oc_block; + }; + + int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1; + auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + + /* step 1: find the best thread distribution with lowest memory cost */ + const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce); + for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, nb_load); + for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast); + auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + if (mem_cost <= best_mem_cost) { + best_mem_cost = mem_cost; + jcp.nthr_mb = nthr_mb; + jcp.nthr_oc_b = nthr_oc_b; + jcp.nthr_ic_b = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads) + jcp.nthr_mb = nstl::min(jcp.mb, nthreads); + + jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b; + assert(jcp.nthr <= nthreads); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp new file mode 100644 index 0000000000..d2ae017943 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp @@ -0,0 +1,108 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP +#define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_common_1x1_conv_kernel : public jit_generator { + jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); + } + + ~jit_avx512_common_1x1_conv_kernel() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel) + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, + int nthreads, bool reduce_src); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + + private: + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + + reg64_t reg_bcast_data = r8; + reg64_t reg_load_data = r10; + reg64_t reg_output_data = r9; + reg64_t aux_reg_bcast_data = r14; + reg64_t aux1_reg_bcast_data = rbx; + reg64_t aux_reg_load_data = r15; + reg64_t imm_addr64 = aux_reg_load_data; + reg64_t aux_reg_output_data = abi_not_param1; + reg64_t reg_load_loop_work = rsi; + reg64_t reg_reduce_loop_work = r11; + reg64_t bcast_loop_iter = rdx; + reg64_t reduce_loop_iter = abi_param1; + reg64_t reg_reduce_pos_flag = rax; + reg64_t reg_output_stride = r13; + reg64_t reg_bias_data = r12; + reg64_t reg_relu_ns = r13; + reg64_t reg_bcast_loop_work = aux1_reg_bcast_data; + + Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + int bcast_loop_work_offt = 0; + int stack_space_needed = 16; + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate(); + static void balance(jit_1x1_conv_conf_t &jcp, int nthreads); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp new file mode 100644 index 0000000000..54d58c8a39 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp @@ -0,0 +1,816 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx512_common_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + + +namespace { +template +void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, + T nx, T &nx_start, T &nx_end, T nx_divider) +{ + const int grp_count = nstl::min(nx_divider, nthr); + const int grp_size_big = nthr / grp_count + 1; + const int grp_size_small = nthr / grp_count; + const int n_grp_big = nthr % grp_count; + const int threads_in_big_groups = n_grp_big * grp_size_big; + + const int ithr_bound_distance = ithr - threads_in_big_groups; + T grp, grp_ithr, grp_nthr; + if (ithr_bound_distance < 0) { // ithr in first groups + grp = ithr / grp_size_big; + grp_ithr = ithr % grp_size_big; + grp_nthr = grp_size_big; + } else { // ithr in last groups + grp = n_grp_big + ithr_bound_distance / grp_size_small; + grp_ithr = ithr_bound_distance % grp_size_small; + grp_nthr = grp_size_small; + } + + balance211(nx, grp_count, grp, nx_start, nx_end); + balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); +} +} +/* convolution forward */ + +template +void jit_avx512_common_1x1_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + const auto &jcp = kernel_->jcp; + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +template +void jit_avx512_common_1x1_convolution_fwd_t:: +execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, + const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + + const int ndims = src_d.ndims(); + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + + auto rp = rtus_driver_t::call_params_t(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count); + + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + int &oh, int &ow, int &ih, int &iw) + { + int osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + }; + + auto init_reduce = [&](int icb) + { + const int nb_ic_blocking_step = + nstl::min(icb + nb_ic_blocking, nb_ic) - icb; + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking_step >= nb_ic + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking_step * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, + int ih, int iw) + { + + const int _ocb = g * nb_oc + ocb; + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + p.bias_data = &bias[_ocb * jcp.oc_block]; + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); + rtus_driver_->ker_(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } +} + + +template struct jit_avx512_common_1x1_convolution_fwd_t; +/* convolution backward wtr data */ + +template +void jit_avx512_common_1x1_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).template get( + key_conv_rtus_space); + + const int ndims = diff_src_d.ndims(); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + parallel(0, [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, icb_start, icb_end, jcp.load_grp_count); + + bool reduce_outer = (jcp.loop_order == loop_rbl + || jcp.loop_order == loop_rlb); + int nboc_outer = reduce_outer ? nb_oc : 1; + int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1; + + int nboc_inner = reduce_outer ? 1 : nb_oc; + int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking; + + for (int ocb_outer = 0; ocb_outer < nboc_outer; + ocb_outer += ocb_outer_step) { + size_t cur_ocb_outer = + nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer; + + int load_step = 0; + for (int icb = icb_start; icb < icb_end; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size(icb * jcp.ic_block, + icb_end * jcp.ic_block, load_step * jcp.ic_block); + rp.icb = p.load_dim / jcp.ic_block; + + int bcast_step; + for (int iwork = bcast_start; iwork < bcast_end; + iwork += bcast_step) + { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + const int _icb = g * nb_ic + icb; + rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb_inner = 0; ocb_inner < nboc_inner; + ocb_inner += ocb_inner_step) { + int cur_ocb_inner = + nstl::min(ocb_inner + ocb_inner_step, nboc_inner) - + ocb_inner; + + int ocb = reduce_outer ? ocb_outer : ocb_inner; + int nb_oc_blocking_step = reduce_outer + ? cur_ocb_outer : cur_ocb_inner; + const int _ocb = g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, + jcp.oc, nb_oc_blocking_step * jcp.oc_block); + + kernel_->jit_ker(&p); + } + if (pd()->rtus_.reduce_src_) + rtus_driver_->ker_(&rp); + } + } + } + }); +} + +template struct jit_avx512_common_1x1_convolution_bwd_data_t; + +/* convolution backward wtr weights */ + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +jit_avx512_common_1x1_convolution_bwd_weights_t :: + jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) + , trans_kernel_(nullptr), rtus_driver_(nullptr) +{ + kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); + acc_ker_ = new cpu_accumulator_1d_t(); + reducer_bias_ = new cpu_reducer_t(pd()->reducer_bia_conf_); + init_rtus_driver(this); + + const auto &jcp = kernel_->jcp; + + if (jcp.transpose_src) { + auto tp = jit_transpose4x16_src_t(); + tp.src_pf0_distance = 4; + tp.tr_src_pf0_distance = 0; + tp.src_pf1 = true; + tp.tr_src_pf1 = false; + trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp); + } +} + +void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + const auto scratchpad = this->scratchpad(ctx); + + auto rtus_space = scratchpad.get(key_conv_rtus_space); + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; + auto wei_reduction = scratchpad.get(key_conv_wei_reduction); + + /* prepare src transposition barriers */ + auto tr_src = scratchpad.get(key_conv_tr_src); + auto tr_src_bctx = scratchpad.get( + key_conv_tr_src_bctx); + if (jcp.transpose_src) { + for (int i = 0; i < jcp.nthr; ++i) + simple_barrier::ctx_init(&tr_src_bctx[i]); + } + + const int ndims = src_d.ndims(); + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic; + + simple_barrier::ctx_t reduction_barrier; + simple_barrier::ctx_init(&reduction_barrier); + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + + const int sp_nb = jcp.nb_reduce; + const int mb_sp_work = jcp.mb * sp_nb; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + // TODO: use memory descriptor with the same fmt as src + // (or use a macro :)) + auto tr_src_off = [&](int img, int icb, int is) { + const size_t tr_chn_size = jcp.tr_is * jcp.ic_block; + const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups; + return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block; + }; + + auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size, + int g_start, int g_work, int ic_b_start, int ic_b_work, + int ithr, int nthr, int first_ic_b) + { + const int work_amount = g_work * ic_b_work; + + int start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + int g{ 0 }, ic_b{ 0 }; + nd_iterator_init(start, g, g_work, ic_b, ic_b_work); + g += g_start; + const int ic_b_tr = g * nb_ic + first_ic_b + ic_b; + ic_b += ic_b_start; + + const int _ic = g * nb_ic + ic_b; + + const int is = sp_b_start * jcp.reduce_block; + const int ih = is / jcp.iw; + const int iw = is % jcp.iw; + + const int src1_off = data_blk_off(src_d, img, _ic, ih, iw); + data_t *src1 = (data_t *)&src[src1_off]; + data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)]; + + assert(jcp.ic_block == 16); + const int src_stride = jcp.is * jcp.ic_block; + const int tr_src_stride = jcp.tr_is * jcp.ic_block; + + const int my_work = end - start; + for (int iwork = 0; iwork < my_work; iwork++) { + auto par_trans = jit_src_transpose_s(); + assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4); + par_trans.size = sp_size; + par_trans.src = src1; + par_trans.tr_src = tr_src1; + par_trans.src_prf = src1 + 64 * 16; + par_trans.tr_src_prf = tr_src1 + 80 * 16; + trans_kernel_->jit_ker(&par_trans); + + src1 += src_stride; + tr_src1 += tr_src_stride; + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1)); + + const int ithr_ic_b = ithr % jcp.nthr_ic_b; + const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; + const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; + const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / + jcp.nthr_g; + + const int ithr_but_oc + = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b; + + /* reduction dimension */ + int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 }; + if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) { + // it's preferable to parallelize by mb if possible + int img_start{ 0 }, img_end{ 0 }; + balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end); + mb_sp_b_start = img_start * sp_nb; + mb_sp_b_end = img_end * sp_nb; + } + else { + balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, + mb_sp_b_end); + } + + /* independent dimensions */ + int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 }; + int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 }; + + balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); + balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, + oc_b_end); + balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, + ic_b_end); + + const int g_work = g_end - g_start; + const int oc_b_work = oc_b_end - oc_b_start; + const int ic_b_work = ic_b_end - ic_b_start; + + data_t *diff_wei = ithr_mb == 0 + ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size; + + int sp_b_step = 0; + for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; + mb_sp_b += sp_b_step) { + int img{ 0 }, sp_b{ 0 }; + nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); + sp_b_step = step(jcp.nb_reduce_blocking, + nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), + jcp.nb_reduce_blocking_max); + + for (int g = g_start; g < g_end; ++g) { + int load_step = 0; + int bcast_step = 0; + for (int ic_b = ic_b_start; ic_b < ic_b_end; + ic_b += bcast_step) { + bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, + jcp.nb_bcast_blocking_max); + if (jcp.transpose_src) { + if (jcp.nthr_oc_b > 1) + simple_barrier::barrier( + &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); + const int sp_size + = nstl::min(sp_b_step * jcp.reduce_block, + jcp.is - sp_b * jcp.reduce_block); + uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b, + bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start); + if (jcp.nthr_oc_b > 1) + simple_barrier::barrier( + &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); + } + + for (int oc_b = oc_b_start; oc_b < oc_b_end; + oc_b += load_step) { + load_step = step(nb_oc_blocking, oc_b_end - oc_b, + jcp.nb_load_blocking_max); + const int _ic_b = g * nb_ic + ic_b; + const int _ic_b_tr = g * nb_ic + ic_b_start; + const int _oc_b = g * nb_oc + oc_b; + + data_t *store_to; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + store_to = diff_wei + off; + + const data_t *diff_src = jcp.transpose_src ? + &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] : + &src[src_d.blk_off(img, _ic_b)]; + + int sp_b_end = sp_b + sp_b_step; + const data_t *pdiff_dst + = &diff_dst[diff_dst_d.blk_off(img, _oc_b)]; + const data_t *local_src = diff_src; + + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t::call_params_t(); + + p.output_stride + = jcp.ic * jcp.oc_block * jcp.typesize_out; + + p.load_dim = load_step * jcp.oc_block; + + p.bcast_dim = bcast_step * jcp.ic_block; + rp.icb = bcast_step; + p.output_data = store_to; + + p.reduce_dim = sp_b_step * jcp.reduce_block; + rp.os = p.reduce_dim; + + p.first_last_flag = 0 + | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0) + | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0); + + int sp = sp_b * jcp.reduce_block; + p.load_data = pdiff_dst + sp * jcp.oc_block; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + sp * jcp.ic_block; + + if (ndims == 3) + rp.src = local_src + iw + * src_d.blocking_desc().strides[2]; + else + rp.src = local_src + ih + * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + rtus_driver_->ker_(&rp); + + p.bcast_data = rp.ws; + } else + p.bcast_data = local_src + sp * jcp.ic_block; + + kernel_->jit_ker(&p); + } + } + } + } + + /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */ + if (jcp.nthr_mb > 1) { + simple_barrier::barrier(&reduction_barrier, jcp.nthr); + const int work = g_work * oc_b_work * ic_b_work; + int start{ 0 }, end{ 0 }; + balance211(work, jcp.nthr_mb, ithr_mb, start, end); + if (start == end) + return; + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + int w = start; + int sub_g_start{ 0 }, sub_oc_b_start{ 0 }, + sub_ic_b_start{ 0 }; + nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, + oc_b_work, sub_ic_b_start, ic_b_work); + while (w < end) { + const int g = g_start + sub_g_start; + const int oc_b = oc_b_start + sub_oc_b_start; + const int ic_b = ic_b_start + sub_ic_b_start; + + const int acc_size + = nstl::min(end - w, ic_b_work - sub_ic_b_start) + * jcp.ic_block * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + data_t *d = diff_weights + off; + data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, g_work, + sub_oc_b_start, oc_b_work, sub_ic_b_start, + ic_b_work); + } + } + } + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) + return; + + /* reduction dimension */ + int img_start{ 0 }, img_end{ 0 }; + + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{ 0 }, ocb_start{ 0 }; + nd_iterator_init( + b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_load + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 16; ++o) + d_bias[o] = 0.; + + for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 16; ++o) + d_bias[o] += d_dst[o]; + d_dst += 16; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + ker(ithr, jcp.nthr); + if (pd()->with_bias()) + ker_bias(ithr, jcp.nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp new file mode 100644 index 0000000000..2e9fda76d6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp @@ -0,0 +1,344 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP +#define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx512_common_1x1_conv_kernel.hpp" +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_transpose_src_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_fwd_t); + + status_t init() { + using namespace utils; + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr(), + mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = + new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_common_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + + private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const dst_data_t *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +using jit_avx512_common_1x1_convolution_fwd_f32_t + = jit_avx512_common_1x1_convolution_fwd_t; + +template +struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, data_type::undef, + diff_dst_type, data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, + *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_common_1x1_convolution_bwd_data_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + + private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +using jit_avx512_common_1x1_convolution_bwd_data_f32_t + = jit_avx512_common_1x1_convolution_bwd_data_t; + +struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t +{ + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_; + cpu_reducer_t::conf_t reducer_bia_conf_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, + jcp_.mb, max_buffer_size)); + } + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd); + + ~jit_avx512_common_1x1_convolution_bwd_weights_t() { + delete kernel_; + delete acc_ker_; + delete reducer_bias_; + delete rtus_driver_; + delete trans_kernel_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + + private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + cpu_accumulator_1d_t *acc_ker_; + cpu_reducer_t *reducer_bias_; + jit_transpose4x16_src *trans_kernel_; + rtus_driver_t *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp new file mode 100644 index 0000000000..235fb02fef --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp @@ -0,0 +1,4539 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" + +#include "jit_avx512_common_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) +#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { + +constexpr auto small_spatial = 14; +unsigned int L1_cache_size = get_cache_size(1, true); + +inline void pick_loop_order(jit_conv_conf_t &jcp) { + using namespace prop_kind; + assert(one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)); + auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow; + auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh; + + // ow-threading is currently implemented for forward only + // TODO: single code for fwd and bwd after ow-thr for bwd + // meaningless switch was removed + if (jcp.prop_kind == backward_data) { + jcp.loop_order = (w <= small_spatial && h <= small_spatial) + ? loop_cgn : loop_gnc; + } else { + jcp.loop_order = (w <= small_spatial && h <= small_spatial) + ? loop_cwgn : loop_gncw; + } +} + +inline bool is_1stconv(const jit_conv_conf_t &jcp) { + if (mayiuse(avx512_core)) + return (jcp.ic < 16 && jcp.ngroups == 1); + else + return one_of(jcp.ic, 1, 3); +} + +inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) { + return (jcp.nb_ow > 1); +} + +inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) { + return (jcp.ver == ver_4fma && is_ow_threading_on(jcp)); +} + +} + +template +void _jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w) +{ + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vpxord(vmm, vmm, vmm); + if (!is_owb_prefetching(jcp)) { + size_t aux_output_offset = get_output_offset(j, k); + mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); + } + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) +{ + Label no_update_label, store_label, eltwise_label; + + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + if (jcp.with_bias) { + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + } + + if (!jcp.with_sum) { + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + } + + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + size_t aux_output_offset = get_output_offset(j, k); + vaddps(vmm, + make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt)); + } + + if (!jcp.with_sum) { + jmp(eltwise_label, T_NEAR); + } else { + cmp(reg_channel, 0); + jne(eltwise_label, T_NEAR); + } + + L(no_update_label); + if (jcp.with_bias) { + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + int bias_offset = jcp.typesize_out * k * jcp.oc_block; + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset)); + } + mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64)); + } + } + + L(eltwise_label); + if (jcp.with_eltwise) { + cmp(reg_channel, jcp.nb_ic - 1); + jl(store_label, T_NEAR); + + if (ur_w == jcp.ur_w) { + eltwise_injector_->compute_vector_range(0, + jcp.nb_oc_blocking * jcp.ur_w); + } else { + for (int k = 0; k < jcp.nb_oc_blocking; k++) + eltwise_injector_->compute_vector_range(k * jcp.ur_w, + k * jcp.ur_w + ur_w); + } + } + + L(store_label); + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + size_t aux_output_offset = (size_t)typesize * + ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block; + vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset, + reg_out_long_offt), vmm); + if (!is_owb_prefetching(jcp)) + mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, + int pad_l, int pad_r) +{ +} + +template<> +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, + int pad_l, int pad_r) +{ + assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0); + + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + Label kh_label, kd_label; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_inp_prf, reg_inp_prf); + } + + size_t max_input_offset = (size_t)jcp.typesize_in + * ((size_t)(kw + ur_w * stride_w - pad_l) + + (size_t)ic_block * iw * ih * jcp.id); + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + + L(kd_label); + } + mov(reg_kj, reg_kh); + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + L(kh_label); + for (int ki = 0; ki < kw; ki += 4) { + for (int ic = 0; ic < ic_block; ic++) { + for (int i = 0; i < 4; i++) { + int aux_ker_offset + = jcp.typesize_in + * ((ki + i) * oc_block + + ic * kw * jcp.kh * jcp.kd * oc_block); + if (ki + i < kw) + vmovups(vmm_ker(i), + EVEX_compress_addr(aux_reg_ker, aux_ker_offset)); + else + vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i)); + } + + int j_start = get_ow_start(ki, pad_l); + int j_end = get_ow_end(ur_w, ki, pad_r); + + for (int j = j_start, prf_count=0; j < j_end; j++) { + size_t aux_input_offset = (size_t)jcp.typesize_in + * ((size_t)(ki + j * stride_w + - pad_l) + (size_t)ic * iw * ih * jcp.id); + v4fmaddps(vmm_out(j, 0), vmm_ker(0), + EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset, + reg_long_offt)); + if (ki + prf_count < kw && prf_count < 4 + && ((ki < 2 && j % 4) || j % 2)) { + int aux_ker_offset = jcp.typesize_in + * ((ki + prf_count) * oc_block + + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_ker_offset)); + prf_count++; + } + if (ki == 0 + && j % (64 / (stride_w * jcp.typesize_in)) == 0) { + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf, + aux_input_offset, reg_long_offt)); + } + if (ki == 1 + && j % (64 / (stride_w * jcp.typesize_in)) == 0) { + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset+jcp.typesize_in * iw, reg_long_offt)); + } + } + } + } + add(aux_reg_ker, jcp.typesize_in * kw * oc_block); + add(aux_reg_inp, jcp.typesize_in * iw); + add(aux_reg_inp_prf, jcp.typesize_in * iw); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block); + add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } + + if (max_input_offset > INT_MAX) pop(reg_inp_prf); +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, + int pad_l, int pad_r) +{ +} + +template<> +void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, + int pad_l, int pad_r) +{ + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + Label kh_label, last_iter_label, loop_end_label, kd_label; + int ker_load_number = 4; + int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block; + int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block; + + bool check_last_kh = (jcp.kh > 3); + bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28); + + int oi_ipref_t0 = get_ow_start(0, pad_l); + int ow_end_ipref = get_ow_end(ur_w, 0, pad_r); + + assert(jcp.oc % jcp.nb_oc_blocking == 0); + + auto kernel_offset = [=](int ocb, int ic, int ki) { + int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int ic_offset = ic * jcp.oc_block; + return typesize * (blk_offset + ic_offset); + }; + auto kernel_loads = [=](int ki, int ic, int kk) { + for (int ii = 0; ii < ker_load_number; ii++) { + int aux_kernel_offset = kernel_offset(kk, ic + ii, ki); + vmovups(vmm_ker(ii), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + }; + auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { + if (cnt1 >= ker_load_number && cnt0 >= ker_load_number + && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) { + int aux_inp_offset + = typesize + * ((oi_ipref_t0 * stride_w - pad_l) * ic_block + + (jcp.dilate_h + 1) * jcp.iw * ic_block); + prefetcht0(EVEX_compress_addr(aux_reg_inp, + aux_inp_offset)); + oi_ipref_t0++; + } + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_ker_prf, reg_ker_prf); + mov(aux_reg_inp_prf, reg_inp_prf); + } + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + align(16); + L(kh_label); + int kw = jcp.kw; + if (check_last_kh) { + for (int ki = 0; ki < kw; ki++) + for (int ic = 0; ic < ic_block; ic += 4) + for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { + bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1 + && ki == kw - 1 && (ic + 4) == ic_block); + + if (last_kernel_loads) { + cmp(reg_kj, 1); + je(last_iter_label, T_NEAR); + } + + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, + prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + + if (oi % 2) { + if (prf_count_t0 < 4) { + int aux_kernel_prf; + if (last_kernel_loads) + aux_kernel_prf= kernel_offset(0, + prf_count_t0 + ic + 4 + - ic_block, 0) + typesize * kw + * oc_block * ic_block; + else + aux_kernel_prf = kernel_offset(kk, ic + 4 + + prf_count_t0, ki); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_kernel_prf)); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, ic + + prf_count_t1, ki))); + prf_count_t1++; + } + } else + prefetch_inp_next_kh(ki, 2, prf_count_t0, + prf_count_t1); + } + + if (last_kernel_loads) { + jmp(loop_end_label, T_NEAR); + + L(last_iter_label); + + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, + prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + if (oi % 2) { + if (prf_count_t0 < 4) { + mic_prefetcht0(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(0, + prf_count_t0, 0))); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + ic + prf_count_t1, ki))); + prf_count_t1++; + } + } + } + L(loop_end_label); + } + } + } else { + for (int ki = 0; ki < kw; ki++) + for (int ic = 0; ic < ic_block; ic += 4) + for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), + prf_count_t1 = 0, prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + + if (!is_owb_prefetching(jcp)) { + if ((oi % 2) && (prf_count_t1 < 4)) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + ic + prf_count_t1, ki))); + prf_count_t1++; + } + } else { + if (!(ki == 0 && ic == 0) + && !(ki == kw-1 && ic == 0) && + (oi % 2) && (prf_count_t1 < 4) + ) { + mic_prefetcht0(EVEX_compress_addr( + aux_reg_ker, kernel_offset(kk, + ic + 4 + prf_count_t0, ki))); + prf_count_t0++; + } + } + if (!is_owb_prefetching(jcp)) { + if (pref_current_inp) { + if (ki == 0 && ic == 0 && kk == 0) + mic_prefetcht0(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + shift_input_ptr)); + } else { + if (ki == 1 && ic == 0 && kk == 0) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_inp_prf, aux_input_offset)); + } + } else { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int inp_shift + = jcp.typesize_in * ur_w * stride_w * inp_mult; + bool kk_pref_slot = kk ? oi % 2 : !(oi % 2); + if (ki == 0 && ic == 0 && kk_pref_slot) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + inp_shift)); + + if (ki == kw - 1 && ic == 0 && kk_pref_slot) + mic_prefetcht0(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + inp_shift)); + } + } + } + } + + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + add(aux_reg_ker_prf, shift_kernel_ptr); + add(aux_reg_inp_prf, shift_input_ptr); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + add(aux_reg_inp_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, + int pad_l, int pad_r) +{ + bool prf_ker = true; + bool prf_inp = true; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + int id = jcp.id; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_oc_block = jcp.nb_oc_blocking; + Label kh_label, kd_label; + + int ker_pipeline_depth = 4; + assert(ker_reg_base_idx + ker_pipeline_depth <= 32); + assert(oc_block >= ker_pipeline_depth); + + int num_ker_loads = ic_block * nb_oc_block * kw; + int num_ker_prfs = prf_ker ? num_ker_loads : 0; + int num_inp_prfs = prf_inp ? + ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) : + 0; + if (jcp.is_1stconv && prf_inp) { + num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block; + } + int num_prfs = num_ker_prfs + num_inp_prfs; + int num_fmas = num_ker_loads * ur_w; + int prf_inst_spacing + = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1; + int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_inp_prf, reg_inp_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id; + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + align(16); + L(kh_label); + { + int step = 0; + int ker_prfs = 0; + for (int ki = 0; ki < kw; ki++) { + for (int ic = 0; ic < ic_block; ic++) { + int aux_kernel_offset = 0; + if (step == 0) { + for (int i = 0; i < ker_pipeline_depth; i++) { + aux_kernel_offset = get_kernel_offset(ki, ic, 0, i); + vmovups(vmm_ker(i), EVEX_compress_addr( + aux_reg_ker, aux_kernel_offset)); + } + } else if (step < num_ker_loads - ker_pipeline_depth + 1) { + int load_offset = ker_pipeline_depth - 1; + int ker_load_reg_idx + = (step + load_offset) % ker_pipeline_depth; + aux_kernel_offset + = get_kernel_offset(ki, ic, 0, load_offset); + vmovups(vmm_ker(ker_load_reg_idx), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + + bool ker_prf_inserted = false; + Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth); + int j_start = get_ow_start(ki, pad_l); + int j_end = get_ow_end(ur_w, ki, pad_r); + for (int j = j_start; j < j_end; j++) { + size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l); + auto addr = EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true); + vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr); + int fma_idx = step * ur_w + j; + int prf_slot_idx = fma_idx / prf_inst_spacing; + if (fma_idx % prf_inst_spacing == prf_inst_trigger) { + if (prf_ker && !ker_prf_inserted + && ker_prfs < num_ker_prfs) { + int ker_prf_offset + = jcp.typesize_in * ker_prfs * jcp.oc_block; + mic_prefetcht2(EVEX_compress_addr( + aux_reg_ker_prf, ker_prf_offset)); + ker_prf_inserted = true; + ker_prfs++; + } else if (prf_inp) { + int inp_prf_idx = prf_slot_idx - ker_prfs; + if (inp_prf_idx < num_inp_prfs) { + size_t inp_prf_stride = nstl::max(kw, stride_w); + size_t inp_prf_offset; + if (!jcp.is_1stconv) { + inp_prf_offset + = ic_block * jcp.typesize_in + * ((inp_prf_idx / kw) + * inp_prf_stride + + (inp_prf_idx % kw)); + } else { + size_t ic_prf_stride = + (size_t)jcp.typesize_in * iw * ih * id; + size_t iw_prf_stride + = jcp.typesize_in * jcp.simd_w; + inp_prf_offset = ((inp_prf_idx / ic_block) + * iw_prf_stride + + (inp_prf_idx % ic_block) + * ic_prf_stride); + } + mic_prefetcht0(EVEX_compress_addr_safe( + aux_reg_inp_prf, inp_prf_offset, + reg_long_offt)); + } + } + } + } + step++; + } + } + add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block); + if (prf_ker) + add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block); + add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + if (prf_inp) + add(aux_reg_inp_prf, + jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + add(aux_reg_inp_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } + if (max_input_offset > INT_MAX) pop(reg_inp_prf); +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, + int pad_l, int pad_r) +{ + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_oc_block = jcp.nb_oc_blocking; + Label kh_label, kd_label; + int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block + * jcp.ic_block; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * inp_mul; + + + auto input_offset = [=](int oi, int ic, int ki) { + return (size_t)jcp.typesize_in + * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * inp_mul + (size_t)ic + * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id)); + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + } + + if (jcp.ndims == 5) { + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + } + + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_ow_start(ki, pad_l); + int jj_end = get_ow_end(ur_w, ki, pad_r); + for (int ic = 0; ic < ic_block; ic++) { + if (jcp.kernel_kind == expl_bcast) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t aux_input_offset = input_offset(jj, ic, ki); + vbroadcastss(vmm_inp(jj, nb_oc_block), + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt)); + } + } + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = jcp.typesize_in + * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block + * oc_block + ki * ic_block * oc_block + ic * oc_block); + if (jj_end - jj_start > 0) + vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); + for (int jj = jj_start; jj < jj_end; jj++) + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(vmm_out(jj, ii), + vmm_inp(jj, nb_oc_block), vmm_wei); + else { + size_t aux_input_offset = input_offset(jj, ic, ki); + vfmadd231ps(vmm_out(jj, ii), vmm_wei, + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true)); + } + } + } + } + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + } +} + +template +void _jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w, + int pad_l, int pad_r) +{ + if (jcp.ndims == 5) push(reg_oi); + + prepare_output(ur_w); + + Label skip_compute_loop; + if (jcp.ndims == 5) { + if ((jcp.dilate_d >= jcp.id) + || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { + mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + } + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + + if (jcp.ver == ver_4fma) + if(jcp.is_1stconv) + compute_loop_4fma_1st(ur_w, pad_l, pad_r); + else + compute_loop_4fma(ur_w, pad_l, pad_r); + else if (jcp.ver == ver_fma) + if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast) + || mayiuse(avx512_mic)) + compute_loop_fma(ur_w, pad_l, pad_r); + else + if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1) + compute_loop_fma(ur_w, pad_l, pad_r); + else + compute_loop_fma_core(ur_w, pad_l, pad_r); + else + assert(!"unknown convolution version"); + + L(skip_compute_loop); + store_output(ur_w); + if (jcp.ndims == 5) pop(reg_oi); +} + +template +void _jit_avx512_common_conv_fwd_kernel::generate() +{ + int iw = jcp.iw; + int ow = jcp.ow; + int ow_block = jcp.ow_block; + int nb_ow = jcp.nb_ow; + int kw = jcp.kw; + int l_pad = jcp.l_pad; + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult; + int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult; + int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult; + int out_shift = jcp.typesize_out * ur_w * jcp.oc_block; + + preamble(); + mov(reg_inp, ptr[param1 + GET_OFF(src)]); + mov(reg_out, ptr[param1 + GET_OFF(dst)]); + mov(reg_ker, ptr[param1 + GET_OFF(filt)]); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); + + int r_pad = nstl::max( + 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1)); + int n_oi = ow / ur_w; + int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + + if (!is_ow_threading_on(jcp)) { + // ow is being processed as a whole - with left and right paddings + if (r_pad1 > 0) n_oi--; + + if (ow == ur_w) { + mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]); + mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]); + compute_loop(ur_w, l_pad, r_pad); + } else { + mov(reg_inp_prf, reg_inp); + mov(reg_out_prf, reg_out); + if (n_oi == 0) { + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, r_pad1); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + } else { + xor_(reg_oi, reg_oi); + if (l_pad > 0) { + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + inc(reg_oi); + } + if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) { + Label ow_loop_label; + L(ow_loop_label); + { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_pad1 > 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, r_pad1); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + } + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + } + } + } else { + // ow block is only processed. + // Number of block is passed as parameter owb, + // and padding processing depends on this number. + + Label end_label, last_oi_label, middle_ow_blocks_label, tail_label; + Label oi_loop_label, oi_loop_start_label, oi_loop_end_label; + + assert(ow_block % ur_w == 0); + int n_oi_not_last_ow_block = ow_block / ur_w; + // to simplify code (and general regs usage), + // size of ow block must be >= 2 * ur_w + assert(n_oi_not_last_ow_block > 1); + int n_oi_next_last_ow_block = n_oi_not_last_ow_block; + int n_oi_first_ow_block = n_oi_not_last_ow_block; + + int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w; + + // prepare right padding + bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; + bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2; + bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0; + + if (last_ow_block_padded) n_oi_last_ow_block--; + else if (first_ow_block_padded) n_oi_first_ow_block--; + else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // is that the first ow-block ? + jg(middle_ow_blocks_label, T_NEAR); + + // the first ow block, compute left padding + + mov(reg_oi, n_oi_first_ow_block); + mov(reg_inp_prf, reg_inp); + mov(reg_out_prf, reg_out); + + if (l_pad > 0) { + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + dec(reg_oi); + } + jmp(oi_loop_label, T_NEAR); + + // middle or last ow block entry + + L(middle_ow_blocks_label); + + if (l_pad > 0) { + // just to consider left padding, not compute + add(reg_inp, inp_shift_pad_second_block); + add(reg_inp_prf, inp_shift_pad_second_block); + } + + // set number of iteration for oi-loop + cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? + mov(reg_oi, n_oi_last_ow_block); + je(oi_loop_label, T_NEAR); + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + mov(reg_oi, n_oi_next_last_ow_block); + je(oi_loop_label, T_NEAR); + mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks + + // oi loop w/o padding + L(oi_loop_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + L(oi_loop_start_label); + cmp(reg_oi, 0); + jle(oi_loop_end_label, T_NEAR); + + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + dec(reg_oi); + jmp(oi_loop_start_label, T_NEAR); + L(oi_loop_end_label); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + + cmp(reg_owb, 0); // first ow-block ? + if (first_ow_block_padded) { + je(last_oi_label, T_NEAR); + } else { + je(end_label, T_NEAR); + } + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + jl(end_label, T_NEAR); + if (next_last_ow_block_padded) { + je(last_oi_label, T_NEAR); + } else { + je(end_label, T_NEAR); + } + // that is last block + if (!last_ow_block_padded) { + jmp(tail_label, T_NEAR); + } + + // last oi block with right padding + L(last_oi_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, r_pad1); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? + jl(end_label, T_NEAR); + + L(tail_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + L(end_label); + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_common_conv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx512_common_conv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, int nthreads) +{ + using namespace prop_kind; + + if (!mayiuse(avx512_common)) + return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const int regs = 28; + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + + jcp = zero(); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + + jcp.is_1stconv = is_1stconv(jcp); + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + + const int full_simd_w = cpu_isa_traits::vlen / sizeof(float); + jcp.simd_w = full_simd_w; + bool ok_to_try_xmm = true + && mayiuse(avx512_core) + && src_d.data_type() == data_type::f32 + && !jcp.is_1stconv + && !ok_to_pad_channels + && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0) + && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0); + if (ok_to_try_xmm) + jcp.simd_w = 4; + + jcp.oc_block = jcp.simd_w; + jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; + jcp.aligned_threads = 0; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0; + if (!args_ok) + return status::unimplemented; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (dst_d.data_type() == data_type::s32) return status::unimplemented; + } + + auto src_tag = jcp.is_1stconv + ? pick(ndims - 3, ncw, nchw, ncdhw) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) + : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c)); + auto dst_tag = (jcp.simd_w == 4) + ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) + : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o) + : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o)); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + } + if (jcp.src_tag != src_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dst_tag)); + jcp.dst_tag = dst_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag); + } + if (jcp.dst_tag != dst_tag) + return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, x)); + } + + if (mayiuse(avx512_common) && + src_d.data_type() == data_type::f32 + && weights_d.data_type() == data_type::f32 + && dst_d.data_type() == data_type::f32) { + jcp.ver = ver_fma; + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + + if (jcp.is_1stconv) { + // TODO: fix & remove constraints below + bool not_for_4fma + = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad), + nstl::max(jcp.kw, jcp.kh) < 7); + bool is_dilated + = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w); + if (one_of(true, not_for_4fma, is_dilated)) + jcp.ver = ver_fma; + if (jcp.ver == ver_4fma) { + wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o) + : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o) + : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o)); + } else { + wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o) + : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o)); + } + } + } else { + return status::unimplemented; + } + + if (weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (jcp.is_1stconv) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } else { + // avx512_core guard - just to avoid possible regression for other archs + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } else { + for (int ur_w = regs; ur_w > 0; --ur_w) { + if (jcp.ow % ur_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + } + if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } + } + // TODO (Tanya): currently applied to Segnet convolutions only. + // Need to try for other topologies + if (jcp.ow > 150 && jcp.ur_w < regs/2) + jcp.ur_w = regs; + + int n_oi = (jcp.ow / jcp.ur_w); + int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); + if (jcp.l_pad > 0 && r_pad > 0) + n_oi--; + + bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0 + && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)); + if (large_code_size) { + const int max_code_size = 24 * 1024; + const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw; + int mult = 1; + if (jcp.l_pad > 0) mult += 1; + if (r_pad > 0) mult += 1; + for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { + if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) { + jcp.ur_w = ur_w; + break; + } + } + } + + /* Grouped channel offset to support 'non-blocked data' format for + * convolution sizes with '(input_channel / ngroups) < simd' */ + jcp.nonblk_group_off + = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ? + jcp.ic : + 1; + + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + + auto is_ow_threading_applicable = [=]() { + return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4) + && IMPLICATION(mayiuse(avx512_mic), + jcp.ver == ver_4fma + && IMPLICATION(jcp.mb != 1, + jcp.ih == 1 && jcp.kh == 1))); + }; + + if (jcp.ver == ver_4fma && !jcp.is_1stconv) { + if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8 + && jcp.oh <= 8 && jcp.ow == jcp.oh) + || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) { + if (jcp.nb_oc % 2 == 0) { + jcp.nb_oc_blocking = 2; + jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); + } + } else { + for (int i = jcp.nb_oc; i > 0; i--) + if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) { + jcp.nb_oc_blocking = i; + break; + } + } + if (jcp.ver == ver_4fma && is_ow_threading_applicable()) { + if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow + && jcp.ow != 2 * jcp.ur_w) { + jcp.nb_oc_blocking = 2; + jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); + } + } + } + + jcp.ow_block = jcp.ow; + + auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) { + int nb_ow = div_up(jcp.ow, ow_block); + int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking); + int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow; + float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block); + float thr_eff = disbalance * (float)work_amount + / rnd_up(work_amount, nthreads); + return thr_eff; + }; + + auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) { + int res_ow_block = jcp.ow; + eff = get_thr_eff(nb_oc_blocking, res_ow_block); + if (!is_ow_threading_applicable()) + return res_ow_block; + + int L2_part = (get_cache_size(2) * 7 / 8) / typesize; + if (jcp.ver == ver_4fma) + L2_part /= 2; + int size_src_chunk = jcp.ic_block * ur_w * jcp.kh; + int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w; + int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block + * jcp.kw * jcp.kh; + int nurw_cache = (L2_part - 2 * size_wei_chunk) + / (2 * size_dst_chunk + 2 * size_src_chunk); + // current design of generate() requires ow_block >= 2 * ur_w + int ow_block_cache = ur_w * nstl::max(2, nurw_cache); + + int ow_block_thr = ow_block_cache; + eff = get_thr_eff(nb_oc_blocking, ow_block_thr); + + int max_nb_ow = div_up(jcp.ow, 2 * ur_w); + int start_nb_ow = div_up(jcp.ow, ow_block_thr); + for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) { + int ow_block + = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow); + float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f; + if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold) + break; + if (div_up(jcp.ow, ow_block) != nb_ow) + continue; + float thr_eff = get_thr_eff(nb_oc_blocking, ow_block); + float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f; + if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) { + ow_block_thr = ow_block; + eff = thr_eff; + } + eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f; + if (eff > eff_threshold) + break; + } + res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr)); + eff = get_thr_eff(nb_oc_blocking, res_ow_block); + return res_ow_block; + }; + + + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + int try_nb_oc_blocking = 2; + unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w) + * jcp.ic_block * jcp.kh * jcp.kd; + unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block + * try_nb_oc_blocking; + unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block + * jcp.oc_block * try_nb_oc_blocking * jcp.kd; + unsigned int ker_total_size = ker_inp_size + ker_out_size + + ker_wei_size; + + bool embd_bcast_condition = true + && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size) + && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192) + && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512); + + if (jcp.mb == 1) { + unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h) + * div_up(jcp.iw, jcp.stride_w) * jcp.ic; + unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw; + + // Estimate whether we need to limit the number of threads + // and calculate this number. Includes some heuristic. + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh; + int job_size_min = work_amount / nthreads; + int job_size_max = div_up(work_amount, nthreads); + int ch_max = rnd_up(jcp.oh, job_size_max); + int ch_min = (job_size_min == 0) + ? jcp.oh + : rnd_up(jcp.oh, job_size_min); + bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2 + && (jcp.oh != 8 || ch_max / jcp.oh > 1); + bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2 + && (jcp.oh != 8 || ch_min / jcp.oh > 1); + bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1) + || nthreads > oc_chunks; + if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1 + && wei_size / inp_size > 24 + && (not_aligned_max || not_aligned_min) + && eligible_case) { + // Try to find nthreads > mkldnn_get_max_threads() / 2 such + // that oc_chunks is a multiple of nthreads, or nthreads is a + // multiple of oc_chunks. Otherwise, keep default value. + // TODO: implement a task-based alternative without throttling. + jcp.aligned_threads = nthreads; + for (int i = nthreads; i > nthreads / 2; i--) { + if (oc_chunks % i == 0 || i % oc_chunks == 0) { + jcp.aligned_threads = i; + break; + } + } + } + } + + if (jcp.kw > 3 + || (jcp.stride_w == 1 && jcp.stride_h == 1 + && embd_bcast_condition) + || ((jcp.stride_w != 1 || jcp.stride_h != 1) + && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10) + && embd_bcast_condition))) + || (jcp.mb == 1 + && (jcp.ur_w >= jcp.ow || jcp.is_1stconv + || (jcp.ow <= 147 && jcp.oc <= 96)))) { + jcp.kernel_kind = embd_bcast; + jcp.ur_w = nstl::min(jcp.ow, regs); + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3 + && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0 + && IMPLICATION(jcp.is_1stconv, jcp.mb == 1) + && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) { + jcp.nb_oc_blocking = try_nb_oc_blocking; + jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); + } + } else { + jcp.kernel_kind = expl_bcast; + jcp.nb_ic_blocking = 1; + if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) { + float best_thr_eff = 0.f; + int best_nb_oc_blocking = 1; + for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) { + if (jcp.nb_oc % i == 0) { + float thr_eff; + int ur_w = nstl::min(jcp.ow, 31 / (i + 1)); + get_ow_block(i, ur_w, thr_eff); + if (thr_eff > 1.05f * best_thr_eff) { + best_nb_oc_blocking = i; + best_thr_eff = thr_eff; + } + } + } + jcp.nb_oc_blocking = best_nb_oc_blocking; + jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); + } + } + } + + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + args_ok = true + && jcp.l_pad <= jcp.ur_w + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) + return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + if (r_pad_no_tail > jcp.ur_w) + return status::unimplemented; + + pick_loop_order(jcp); + + jcp.nb_ic_L2 = jcp.nb_ic; + + float thr_eff; + jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff); + jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); + + const int L2_size = get_cache_size(2, true) / sizeof(float); + // Source and output data needs to fit in L2, + // leaving some space for weights and prefetching. + int h_L2 = int(((0.6f * L2_size) / jcp.simd_w + - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw) + / (jcp.stride_h * jcp.iw + jcp.ow)); + jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2)); + + if (jcp.ver == ver_4fma) { + if (!is_ow_threading_on(jcp)) { + for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic; + divf++) { + size_t l2_src + = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id; + size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking + * jcp.oh * jcp.od; + size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block + * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd; + if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { + if (jcp.kh == 3 && jcp.oh == 7) { + jcp.nb_ic_L2 = 1; + break; + } + temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf + : jcp.nb_ic_L2); + } else { + jcp.nb_ic_L2 = temp_nb; + break; + } + } + } else if (jcp.ic > 64) { + jcp.nb_ic_L2 = 2; /* according to performance data*/ + } + } + + return status::success; +} + +void jit_avx512_common_conv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w) +{ + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + vpxord(zmm, zmm, zmm); + size_t aux_src_offset + = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) + * jcp.ic_block; + mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); + } + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) +{ + Label no_update_label; + + mov(reg_channel, ptr[param + GET_OFF(channel)]); + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt)); + } + } + + L(no_update_label); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt), zmm); + mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); + } + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( + int ur_w, int l_overflow, int r_overflow) +{ + int ow = jcp.ow; + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + Label kh_label, last_iter_label, loop_end_label, kd_label; + int ker_load_number = 4; + int shift_ker_ptr = typesize * kw * oc_block * ic_block; + int shift_dst_ptr = typesize * ow * oc_block; + int ii_dpref_t0 = get_iw_start(0, l_overflow); + int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow); + + bool check_last_kh = (jcp.kh > 3); + auto kernel_offset = [=](int icb, int oc, int ki) { + int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int oc_offset = oc * jcp.oc_block; + return typesize * (blk_offset + oc_offset); + }; + auto kernel_loads = [=](int ki, int oc, int kk) { + for (int ii = 0; ii < ker_load_number; ii++) { + int aux_kernel_offset = kernel_offset(kk, oc + ii, ki); + vmovups(zmm_ker(ii), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + }; + auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { + if (cnt1 >= ker_load_number && cnt0 >= ker_load_number + && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) { + int aux_dst_offset = typesize * ((ii_dpref_t0 + + jcp.l_pad) * oc_block + jcp.ow * oc_block); + prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + ii_dpref_t0++; + } + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + mov(aux_reg_dst_d_prf, reg_dst_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_dst_prf, aux_reg_dst_d_prf); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + } + + align(16); + L(kh_label); + if (check_last_kh) { + for (int ki = 0; ki < kw; ki++) + for (int oc = 0; oc < oc_block; oc += 4) + for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { + bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1 + && ki == kw - 1 && (oc + 4) == oc_block); + + if (last_kernel_loads) { + cmp(reg_kj, 1); + je(last_iter_label, T_NEAR); + } + + kernel_loads(ki, oc, kk); + for (int ii = get_iw_start(ki, l_overflow), + prf_count_t0 = 0, prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + + if (ii % 2) { + if (prf_count_t0 < 4) { + int aux_kernel_prf; + if (last_kernel_loads) + aux_kernel_prf= kernel_offset(0, prf_count_t0 + + oc + 4 - oc_block, 0) + typesize * kw + * oc_block * ic_block; + else + aux_kernel_prf = kernel_offset(kk, oc + 4 + + prf_count_t0, ki); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_kernel_prf)); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(kk, oc + prf_count_t1, ki))); + prf_count_t1++; + } + } else + prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1); + } + if (last_kernel_loads) { + jmp(loop_end_label, T_NEAR); + + L(last_iter_label); + + kernel_loads(ki, oc, kk); + for (int ii = get_iw_start(ki, l_overflow), + prf_count_t0 = 0, prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + if (ii % 2) { + if (prf_count_t0 < 4) { + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(0, prf_count_t0, 0))); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(kk, oc + prf_count_t1, ki))); + prf_count_t1++; + } + } + } + L(loop_end_label); + } + } + } else { + for (int ki = 0; ki < kw; ki++) + for (int oc = 0; oc < oc_block; oc += 4) + for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { + kernel_loads(ki, oc, kk); + + for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + if ((ii % 2) && (prf_count_t1 < 4)) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + oc + prf_count_t1, ki))); + prf_count_t1++; + } + if ( ki == 1 && oc == 0 && kk == 0) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_dst_prf, aux_dst_offset)); + } + } + } + + add(aux_reg_ker, shift_ker_ptr); + sub(aux_reg_dst, shift_dst_ptr); + add(aux_reg_ker_prf, shift_ker_ptr); + sub(aux_reg_dst_prf, shift_dst_ptr); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); + sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh *oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_src); + pop(reg_src_prf); + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma( + int ur_w, int l_overflow, int r_overflow) +{ + Label kh_label, kd_label; + int kw = jcp.kw; + int ow = jcp.ow; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int l_pad = jcp.l_pad; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + int stride_h = jcp.stride_h; + + int ker_pipeline_depth = 4; + assert(ker_reg_base_idx + ker_pipeline_depth <= 32); + assert(oc_block >= ker_pipeline_depth); + + int num_ker_loads = oc_block * kw; + int num_inp_prfs = ur_w * nstl::min(kw, stride_w) + + nstl::max(0, kw - stride_w); + int num_prfs = num_ker_loads + num_inp_prfs; + int num_fmas = num_ker_loads * ur_w / stride_w; + int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs); + int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + mov(aux_reg_dst_d_prf, reg_dst_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_dst_prf, aux_reg_dst_d_prf); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + } + + L(kh_label); { + int step = 0; + int ker_prfs = 0; + for (int ki = 0; ki < kw; ki++) { + for (int oc = 0; oc < oc_block; oc++) { + if (step == 0) { + for (int i = 0; i < ker_pipeline_depth; i++) { + int aux_kernel_offset = typesize * ((oc + i) * oc_block + + ki * ic_block * oc_block); + vmovups(zmm_ker(i), EVEX_compress_addr( + aux_reg_ker, aux_kernel_offset)); + } + } else if (step < num_ker_loads - ker_pipeline_depth + 1) { + int load_offset = ker_pipeline_depth - 1; + int ker_load_reg_idx + = (step + load_offset) % ker_pipeline_depth; + int aux_kernel_offset = typesize * ((oc + load_offset) + * oc_block + ki * ic_block * oc_block); + vmovups(zmm_ker(ker_load_reg_idx), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + + bool ker_prf_inserted = false; + auto zmm_kernel = zmm_ker(step % ker_pipeline_depth); + + int jj_start = get_iw_start(ki, l_overflow); + int jj_end = get_iw_end(ur_w, ki, r_overflow); + assert(stride_w != 1 + || jj_start == nstl::max(0, + l_overflow - (kw - 1 - ki) * dilate_w)); + assert(stride_w != 1 + || jj_end == ur_w - nstl::max(0, + r_overflow - ki * dilate_w)); + + for (int jj = jj_start; jj < jj_end; jj += stride_w) { + assert((jj + l_pad - ki * dilate_w) % stride_w == 0); + int aux_dst_offset = typesize * + (((jj + l_pad - ki * dilate_w) + / stride_w) * jcp.oc_block + oc); + vfmadd231ps(zmm_out(jj, 0), zmm_kernel, + EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true)); + + int fma_idx = (step * ur_w + jj) / stride_w; + int prf_slot_idx = fma_idx / prf_inst_spacing; + if (fma_idx % prf_inst_spacing == prf_inst_trigger) { + if (!ker_prf_inserted && ker_prfs < num_ker_loads) { + int ker_prf_offset = typesize + * ker_prfs * jcp.oc_block; + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, ker_prf_offset)); + ker_prf_inserted = true; + ker_prfs++; + } else { + int inp_prf_idx = prf_slot_idx - ker_prfs; + if (inp_prf_idx < num_inp_prfs) { + int inp_prf_offset + = ic_block * typesize + * ((inp_prf_idx / kw) * kw + + (inp_prf_idx % kw)); + mic_prefetcht0(EVEX_compress_addr( + aux_reg_dst_prf, inp_prf_offset)); + } + } + } + } + step++; + } + } + + add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block); + sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block); + add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block); + sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh + * oc_block * ic_block); + sub(aux_reg_dst_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh + * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + } + + if (jcp.ndims == 5) + { + pop(reg_src); + pop(reg_src_prf); + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( + int ur_w, int l_overflow, int r_overflow) +{ + int kw = jcp.kw; + int ow = jcp.ow; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_ic_block = jcp.nb_ic_blocking; + Label kh_label, kd_label; + + int shift_ker_ptr = typesize * kw * oc_block * ic_block; + int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block; + + auto output_offset = [=](int oi, int oc, int ki) { + return typesize * + (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc); + }; + auto kernel_offset = [=](int icb, int oc, int ki) { + int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int oc_offset = oc * jcp.oc_block; + return typesize * (blk_offset + oc_offset); + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + } + + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_iw_start(ki, l_overflow); + int jj_end = get_iw_end(ur_w, ki, r_overflow); + for (int oc = 0; oc < oc_block; oc++) { + if (jcp.kernel_kind == expl_bcast) { + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_output_offset = output_offset(jj, oc, ki); + vbroadcastss(zmm_inp(jj, nb_ic_block), + ptr[aux_reg_dst + aux_output_offset]); + } + } + for (int ii = 0; ii < nb_ic_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, oc, ki); + if (jj_end - jj_start > 0) + vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); + for (int jj = jj_start; jj < jj_end; jj += stride_w) + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(zmm_out(jj, ii), + zmm_inp(jj, nb_ic_block), zmm_wei); + else + vfmadd231ps(zmm_out(jj, ii), zmm_wei, + EVEX_compress_addr(aux_reg_dst, + output_offset(jj, oc, ki), true)); + } + } + } + add(aux_reg_ker, shift_ker_ptr); + sub(aux_reg_dst, shift_dst_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_src); + pop(reg_src_prf); + } +} + +inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( + int ur_w, int l_overflow, int r_overflow) +{ + if (jcp.ndims == 5) push(reg_oi); + + prepare_output(ur_w); + + Label skip_compute_loop; + if (jcp.ndims == 5) { + mov(reg_kj, ptr[param + GET_OFF(kd_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + + if (jcp.ver == ver_4fma) + compute_loop_4fma(ur_w, l_overflow, r_overflow); + else if (jcp.ver == ver_fma) + if (mayiuse(avx512_mic)) + compute_loop_fma(ur_w, l_overflow, r_overflow); + else + if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1) + compute_loop_fma(ur_w, l_overflow, r_overflow); + else + compute_loop_fma_core(ur_w, l_overflow, r_overflow); + else + assert("!unknown convolution version"); + + L(skip_compute_loop); + store_output(ur_w); + if (jcp.ndims == 5) pop(reg_oi); +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::generate() +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ur_w = jcp.ur_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int ur_w_tail = jcp.ur_w_tail; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block; + int src_shift = jcp.typesize_out * ur_w * oc_block; + + preamble(); + + mov(reg_src, ptr[param + GET_OFF(src)]); + mov(reg_dst, ptr[param + GET_OFF(dst)]); + mov(reg_ker, ptr[param + GET_OFF(filt)]); + + mov(reg_kh, ptr[param + GET_OFF(kh_padding)]); + mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]); + mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]); + mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]); + + int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w); + int r_overflow = nstl::max(0, ((kw - 1) * dilate_w + - nstl::max(0, jcp.r_pad)) / stride_w); + int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w + - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w); + + int n_oi = iw / ur_w; + if (r_overflow1 > 0) n_oi--; + + if (ur_w == iw) { + compute_loop(ur_w, l_overflow, r_overflow); + } else if (n_oi == 0) { + compute_loop(ur_w, l_overflow, r_overflow1); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + if (ur_w_tail != 0) + compute_loop(ur_w_tail, 0, r_overflow); + } else { + xor_(reg_oi, reg_oi); + if (l_overflow > 0) { + compute_loop(ur_w, l_overflow, 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + + inc(reg_oi); + } + if ((l_overflow <= 0 && n_oi > 0) + || (l_overflow > 0 && n_oi > 1)) { + Label ow_loop_label; + L(ow_loop_label); { + compute_loop(ur_w, 0, 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_overflow1 > 0) { + compute_loop(ur_w, 0, r_overflow1); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + } + if (ur_w_tail != 0) { + compute_loop(ur_w_tail, 0, r_overflow); + } + } + + postamble(); +} + +status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( + jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + if (!mayiuse(avx512_common)) return status::unimplemented; + + jcp = zero(); + + jcp.simd_w = cpu_isa_traits::vlen / sizeof(float); + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + int ndims = diff_src_d.ndims(); + + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; + jcp.iw = diff_src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + if ((jcp.dilate_w != 0 && jcp.stride_w != 1) + || (jcp.dilate_d != 0 && jcp.stride_d != 1) + || (jcp.dilate_h != 0 && jcp.stride_h != 1)) + return status::unimplemented; + + jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1); + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + + jcp.aligned_threads = 0; + + jcp.is_1stconv = false; + + jcp.oc_block = jcp.simd_w; + jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && diff_src_d.data_type() == data_type::f32; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + + auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i) + : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i); + jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) + return status::unimplemented; + + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.ur_w = jcp.stride_w; + + int regs = 28; + if (jcp.iw <= regs) + jcp.ur_w = jcp.iw; + else { + for (int ur_w = regs; ur_w > 0; --ur_w) + if (ur_w % jcp.stride_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - jcp.l_pad) / jcp.stride_w); + int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w); + int n_oi = jcp.iw / jcp.ur_w; + if (r_overflow1 > 0) n_oi--; + + if (mayiuse(avx512_common) + && diff_dst_d.data_type() == data_type::f32 + && weights_d.data_type() == data_type::f32 + && diff_src_d.data_type() == data_type::f32) { + jcp.ver = ver_fma; + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + if (mayiuse(avx512_mic_4ops) + && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) { + jcp.ver = ver_4fma; + } + } else { + return status::unimplemented; + } + + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) + && jcp.ver != ver_fma) + return status::unimplemented; + + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (jcp.ver == ver_4fma) { + if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) { + jcp.nb_ic_blocking = 2; + } else { + for (int i = jcp.nb_ic; i > 0; i--) + if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) { + jcp.nb_ic_blocking = i; + break; + } + } + } + + jcp.loop_order = loop_gnc; + + bool large_code_size = (jcp.ur_w != jcp.ow) + && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1)) + && (r_overflow1 > 0) && (l_overflow > 0); + if (large_code_size) { + const int max_code_size = 24 * 1024; + const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw; + int mult = 1; + if (l_overflow > 0) mult += 1; + if (r_overflow1 > 0) mult += 1; + for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { + if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2 + < max_code_size) { + if (ur_w % jcp.stride_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + } + } + + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + int try_nb_ic_blocking = 2; + unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block + * try_nb_ic_blocking * jcp.kh; + unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block; + unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block + * jcp.oc_block * try_nb_ic_blocking; + unsigned int ker_total_size = ker_inp_size + ker_out_size + + ker_wei_size; + if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8) + || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13)) + || ker_total_size > L1_cache_size ))) + || jcp.stride_h > 1 || jcp.stride_d > 1) { + jcp.kernel_kind = embd_bcast; + jcp.ur_w = nstl::min(jcp.iw, regs); + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size + && jcp.ow > 8)) && jcp.stride_h == 1) + if (jcp.nb_ic % try_nb_ic_blocking == 0) { + jcp.nb_ic_blocking = try_nb_ic_blocking; + jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); + if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; + } + } else { + jcp.kernel_kind = expl_bcast; + jcp.nb_oc_blocking = 1; + jcp.nb_ic_blocking = 4; + if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic; + if (jcp.nb_ic % jcp.nb_ic_blocking != 0) + for (int i = jcp.nb_ic_blocking; i > 0; i--) + if (jcp.nb_ic % i == 0) { + jcp.nb_ic_blocking = i; + break; + } + jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); + if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; + } + } + jcp.ur_w_tail = jcp.iw % jcp.ur_w; + + if (l_overflow * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) + return status::unimplemented; + + pick_loop_order(jcp); + + jcp.nb_oc_L2 = jcp.nb_oc; + if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) { + for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc; + divf++) { + size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih + * jcp.id; + size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od; + size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh + * jcp.kd * jcp.nb_ic_blocking * temp_nb; + if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { + if (jcp.kh == 3 && jcp.ih == 7) { + jcp.nb_oc_L2 = 1; + break; + } + temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf + : jcp.nb_oc_L2); + } else { + jcp.nb_oc_L2 = temp_nb; + break; + } + } + } + + args_ok = true + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + return status::success; +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28; + +void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() +{ + Label kd_comeback_label; + + /* 'depth' loop count bound by 'kd_work_size' */ + mov(kj, reg_kd_count); + L(kd_comeback_label); { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + sub(reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult); + sub(reg_kernel, + jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kd_comeback_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() +{ + Label kh_comeback_label, kd_comeback_label; + mov(kj, reg_kh); + L(kh_comeback_label); { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult); + sub(reg_kernel, + jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_comeback_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) + vmovups(Zmm(i_kw * ic_block_step + i_ic), + EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block + + i_ic) * jcp.oc_block + kernel_offset)); + + for (int i_ur = 0; i_ur < ur_w; i_ur++) { + if (i_ur == 0) { + vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 0) + * oc_block + output_offset)); + if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block + + output_offset)); + if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block + + output_offset)); + if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block + + output_offset)); + } else if (i_ur + 3 < ur_w) + vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block + + output_offset)); + + for (int i_kw = 0; i_kw < kw; i_kw++) { + int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1); + if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + + (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue; + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + const size_t i_offset = (size_t)input_offset + + (size_t)typesize * (jcp.ver == ver_4fma + ? (i_iw - pad_l + i_ic * jcp.tr_iw) + : (jcp.is_1stconv + ? (i_iw - pad_l) + (size_t)i_ic + * ((size_t)jcp.ih*jcp.iw*jcp.id) + : (i_iw - pad_l) * ic_block + i_ic)); + vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic), + Zmm(kw * ic_block_step + i_ur % 4), + EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt, + true)); + } + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) + vmovups(EVEX_compress_addr(reg_kernel, typesize + * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset), + Zmm(i_kw * ic_block_step + i_ic)); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + // TODO: add prefetches to fma version as well + + assert(jcp.ver == ver_4fma); + + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + auto zmm_ker = [=](int i_kw, int i_ic) { + return Zmm(i_kw * ic_block_step + i_ic); + }; + + auto ker_addr = [=](int i_kw, int i_ic) { + size_t local_offset + = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block; + return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset); + }; + + auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) { + int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1); + int local_offset = jcp.typesize_in * (i_iw + i_ic * stride); + return EVEX_compress_addr(reg_input, + local_offset + input_offset + extra_offset); + }; + + auto zmm_out = [=](int i_iw) { + // TODO: move reg calc to global member funcs + const int out_zmm_base_idx = 28; + return Zmm(out_zmm_base_idx + i_iw % 4); + }; + + auto out_addr = [=](int i_ur) { + return EVEX_compress_addr(reg_output, + jcp.typesize_in * i_ur * oc_block + output_offset); + }; + + auto pf_callback = [=](int i_ur, int i_kw, int i_ic) { + assert(i_ur % 4 == 0); + if (i_ur == 0) + prefetcht1(ker_addr(i_kw, i_ic)); + if (i_ur + 4 >= ur_w) + prefetcht0(ker_addr(i_kw, i_ic)); + + const ptrdiff_t next_input_block_offset + = jcp.typesize_in * ic_block_step * jcp.tr_iw; + if (i_ur % 16 == 4 && i_kw == 0) { + if (i_ur + 16 < ur_w) + prefetcht0(inp_addr(i_ur + 16, i_ic)); + else + prefetcht0(inp_addr(0, i_ic, next_input_block_offset)); + } + if (i_ur % 16 == 4 && i_kw == 1) { + if (input_wraparound) + prefetcht1(inp_addr(i_ur, i_ic, -input_offset)); + else + prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset)); + } + }; + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + auto zmm = zmm_ker(i_kw, i_ic); + vpxord(zmm, zmm, zmm); + } + + for (int i_ur = 0; i_ur < ur_w; i_ur += 4) { + + for (int i = 0; i < 4; i++) { + auto zmm = zmm_out(i_ur + i); + if (i_ur + i < ur_w) + vmovups(zmm, out_addr(i_ur + i)); + else + vpxord(zmm, zmm, zmm); + prefetcht0(out_addr(i_ur + i + 4)); + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + int i_iw = i_ur + i_kw; + v4fmaddps(zmm_ker(i_kw, i_ic), + zmm_out(i_ur), inp_addr(i_iw, i_ic)); + pf_callback(i_ur, i_kw, i_ic); + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + auto addr = ker_addr(i_kw, i_ic); + auto zmm = zmm_ker(i_kw, i_ic); + vaddps(zmm, zmm, addr); + vmovups(addr, zmm); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + if (jcp.ver == ver_4fma) + compute_ic_block_step_4fma(ur_w, pad_l, pad_r, + ic_block_step, input_offset, kernel_offset, output_offset, + input_wraparound); + else if (jcp.ver == ver_fma) + compute_ic_block_step_fma(ur_w, pad_l, pad_r, + ic_block_step, input_offset, kernel_offset, output_offset, + input_wraparound); + else + assert(!"unknown convolution version"); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_unroll_ow_icblock( + int ic_block_step, int max_ur_w) +{ + UNUSED(max_ur_w); + + Label kh_label, kd_label; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + int ow = jcp.ow; + + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.l_pad; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); + { + for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) { + const int input_offset = jcp.typesize_in + * (jcp.ver == ver_4fma ? i_b_ic * iw : i_b_ic); + compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step, + input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0, + i_b_ic + ic_block_step >= jcp.ic_block); + } + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_unroll_ow( + int ic_block_step, int max_ur_w) +{ + Label kh_label, ic_block_label, kd_label; + + UNUSED(max_ur_w); + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + int ow = jcp.ow; + + int r_pad = nstl::max(0, + (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.l_pad; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); + { + xor_(b_ic, b_ic); + L(ic_block_label); { + compute_ic_block_step(ow, l_pad, r_pad, ic_block_step, + 0, 0, 0); + size_t inp_icblk_stride = jcp.is_1stconv + ? (size_t)jcp.ih * jcp.iw * jcp.id + : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); + add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_label, T_NEAR); + } + + if (jcp.is_1stconv) { + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); + } else if (jcp.ver != ver_4fma) { + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block); + } + add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_common( + int ic_block_step, int max_ur_w) +{ + Label kh_label, ic_block_label, ow_block_label, kd_label; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + int ow = jcp.ow; + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad; + + int ur_w = nstl::min(ow, max_ur_w); + int ur_w_trips = ow / ur_w; + int ur_w_tail = ow % ur_w; + if ((ur_w_tail == 0 && r_pad != 0) + || r_pad >= ur_w_tail) { + if (ur_w_trips > 1) { + ur_w_tail += ur_w; + ur_w_trips--; + } else { + ur_w_tail += (ur_w - ur_w / 2); + ur_w = ur_w / 2; + } + } + + int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 1 : ic_block; + int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult; + int output_comeback = ur_w_trips * ur_w * oc_block; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); { + xor_(b_ic, b_ic); + L(ic_block_label); { + if (l_pad != 0) { + ur_w_trips--; + compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0); + add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad) + * inp_mult); + add(reg_output, jcp.typesize_in * ur_w * oc_block); + } + + if (ur_w_trips > 0) { + xor_(reg_ur_w_trips, reg_ur_w_trips); + L(ow_block_label); { + compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); + add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w + * inp_mult); + add(reg_output, jcp.typesize_in * ur_w * oc_block); + + inc(reg_ur_w_trips); + cmp(reg_ur_w_trips, ur_w_trips); + jl(ow_block_label, T_NEAR); + } + } + + if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad, + ic_block_step, 0, 0, 0); + + sub(reg_input, jcp.typesize_in * input_comeback); + sub(reg_output, jcp.typesize_in * output_comeback); + int inp_icblk_stride = jcp.is_1stconv + ? jcp.ih * jcp.iw * jcp.id + : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); + add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); + + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_label, T_NEAR); + } + if (jcp.is_1stconv) { + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); + } else if (jcp.ver != ver_4fma) { + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block); + } + add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_disp() +{ + int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2); + if (jcp.is_1stconv) { + bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0); + ic_block_step + = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1; + } + + bool too_large_to_unroll + = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1) + && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1); + + int ow = jcp.ow; + if (jcp.ndims == 5) { + /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of + * 'movs' must be guaranteed. */ + mov(ki, reg_kd_count); + push(reg_kd_count); + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + } + + if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) + compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w); + else if (ow <= max_ur_w) + compute_oh_step_unroll_ow(ic_block_step, max_ur_w); + else + compute_oh_step_common(ic_block_step, max_ur_w); + + if (jcp.ndims == 5) { + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + pop(reg_kd_count); + od_step_comeback_pointers(); + } else { + oh_step_comeback_pointers(); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel() +{ + Label skip_zeroing, zeroing_loop; + + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + cmp(reg_tmp, 0); + jz(skip_zeroing, T_NEAR); + + Zmm zero = Zmm(0); + vpxord(zero, zero, zero); + xor_(reg_tmp, reg_tmp); + L(zeroing_loop); { + assert(jcp.oc_block * jcp.typesize_out + == cpu_isa_traits::vlen); + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block + * jcp.typesize_out], zero); + add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out); + cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd + * jcp.typesize_out); + jnz(zeroing_loop); + } + + L(skip_zeroing); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel() +{ + Label skip_bias, bias_loop, skip_load_bias; + + mov(reg_tmp, ptr[param + GET_OFF(flags)]); + test(reg_tmp,reg_tmp); + jne(skip_bias, T_NEAR); + + mov(reg_bias, ptr[param + GET_OFF(bias)]); + mov(reg_output, ptr[param + GET_OFF(dst)]); + vpxord(Zmm(1), Zmm(1), Zmm(1)); + + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + cmp(reg_tmp, 0); + jne(skip_load_bias, T_NEAR); + vmovups(Zmm(1), ptr[reg_bias]); + + L(skip_load_bias); + + mov(reg_oi, ptr[param + GET_OFF(d_worksize)]); + sub(reg_oi, ptr[param + GET_OFF(d_index)]); + mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out); + imul(reg_oi, reg_tmp); + + xor_(reg_tmp, reg_tmp); + L(bias_loop); { + vmovups(Zmm(0), ptr[reg_output + reg_tmp]); + vaddps(Zmm(1), Zmm(1), Zmm(0)); + add(reg_tmp, jcp.oc_block * jcp.typesize_out); + cmp(reg_tmp, reg_oi); + jl(bias_loop); + } + vmovups(EVEX_compress_addr(reg_bias,0), Zmm(1)); + + L(skip_bias); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_loop_common() +{ + int b_pad = jcp.b_pad; + int t_pad = jcp.t_pad; + bool is_dilated = jcp.dilate_h != 0; + int dilate_h = jcp.dilate_h + 1; + int stride_h = jcp.stride_h; + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label, + oh_bpad_label, oh_bpad_label_end, od_label, od_label_end, + oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end; + + int ow = jcp.ow; + + mov(reg_kh, jcp.kh); + xor_(reg_ih_count, reg_ih_count); + xor_(reg_oj, reg_oj); + /* Compute 'top' edge */ + if (t_pad > 0) { + const int kh_range = 1 + (jcp.kh - 1) * dilate_h; + const int overflow + = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h)); + const int underflow = div_up(t_pad, dilate_h); + const int initial_inp_ker_overlap = jcp.kh - overflow - underflow; + mov(reg_kh, initial_inp_ker_overlap); + add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block + * jcp.oc_block); + // generate loop to process kernel while it remains within t_pad + ih + if (kh_range < t_pad + jcp.ih) { + if (is_dilated) { + const int tail = t_pad % dilate_h; + const int shift = tail == 0 ? 0 : dilate_h - tail; + mov(reg_tmp, shift); + if (tail != 0) + add(reg_input, jcp.typesize_in * shift * iw * inp_mult); + } + L(oh_tpad_label); { + compute_oh_step_disp(); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_shift, T_NEAR); + // unshift input as new kernel element enters + sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult); + xor_(reg_tmp, reg_tmp); + } + // kernel overlap only changes when (t_pad + oj) % dilate_h == 0 + sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw + * jcp.ic_block * jcp.oc_block); + add(reg_kh, stride_h); + if (is_dilated) { + jmp(oh_dilate_label_noshift, T_NEAR); + L(oh_dilate_label_shift); + // shift input as old kernel element progresses + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + L(oh_dilate_label_noshift); + } + inc(reg_oj); + add(reg_ih_count, stride_h); + + // final number of kernel elements that overlap with input + const int final_inp_ker_overlap + = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h)); + cmp(reg_kh, final_inp_ker_overlap); + jl(oh_tpad_label, T_NEAR); + } + } + // need second loop to process kernel if it is larger than the input + // (does not apply to dilations as they must have unit stride) + if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h : + t_pad % stride_h)) { + assert(!is_dilated); + mov(reg_kh, jcp.ih); + L(oh_tpad_tail_label); { + compute_oh_step_disp(); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw + * jcp.ic_block * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h)); + jl(oh_tpad_tail_label, T_NEAR); + } + } + // correct any excess shifts to kernel and input + // (does not apply to dilations as they must have unit stride, + // kernel must fit inside input, and padding is smaller than input) + if (t_pad <= jcp.oh * stride_h) { + // kernel has moved beyond padding (adjust for stride effects) + if (t_pad % stride_h != 0) { + assert(!is_dilated); + int inp_corr = stride_h - t_pad % stride_h; + add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw + * jcp.ic_block * jcp.oc_block); + add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult); + } + } else { + // kernel still overlaps padding (complete reset) + assert(!is_dilated); + sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h) + * jcp.kw * jcp.ic_block * jcp.oc_block); + } + } + + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); + jge(oh_label_end, T_NEAR); + cmp(reg_oj, jcp.oh); + jge(oh_label, T_NEAR); + + /* Compute middle block(s) */ + mov(reg_kh, jcp.kh); + L(oh_label); { + compute_oh_step_disp(); + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); + jge(oh_label_end, T_NEAR); + + cmp(reg_oj, jcp.oh); + jl(oh_label, T_NEAR); + } + L(oh_label_end); + + /* Compute bottom edge */ + if (b_pad > 0) { + cmp(reg_oj, jcp.oh); + jge(oh_bpad_label_end, T_NEAR); + + if (is_dilated) { + mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations + mov(reg_tmp, 0); + } else { + mov(reg_kh, jcp.ihp - b_pad); + sub(reg_kh, reg_ih_count); + } + L(oh_bpad_label); + { + compute_oh_step_disp(); + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_end, T_NEAR); + xor_(reg_tmp, reg_tmp); + } + sub(reg_kh, stride_h); + cmp(reg_kh, 0); + jle(oh_bpad_label_end, T_NEAR); + if (is_dilated) + L(oh_dilate_label_end); + + inc(reg_oj); + cmp(reg_oj, jcp.oh); + jl(oh_bpad_label, T_NEAR); + } + L(oh_bpad_label_end); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() { + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + int ow = jcp.ow; + const int input_backpad_overlap + = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d); + + const size_t filter_shift + = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block; + const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult; + const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block; + + Label d_loop_label, loop_end_label, common_block_label, fpad_end_label, + backpad_end_label, backpad_label; + + if (jcp.with_bias) bias_kernel(); + + /* initially offset 'kd' by f_pad */ + add(reg_kernel, ptr[param + GET_OFF(kd_offset)]); + + mov(reg_input_d, ptr[param + GET_OFF(src)]); + mov(reg_output_d, ptr[param + GET_OFF(dst)]); + mov(reg_d_index, ptr[param + GET_OFF(d_index)]); + mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]); + + cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); + jge(loop_end_label, T_NEAR); + + L(d_loop_label); + + mov(reg_input, reg_input_d); + mov(reg_output, reg_output_d); + + push(reg_input_d); + push(reg_output_d); + push(reg_d_index); + + compute_oh_loop_common(); + + pop(reg_d_index); + pop(reg_output_d); + pop(reg_input_d); + + /* Compute 'front' edge */ + if (jcp.f_pad > 0) { + + /* Check if within fpad region */ + cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d)); + jge(fpad_end_label, T_NEAR); + + /* Fpad steps */ + sub(reg_kernel, filter_shift * jcp.stride_d); + add(reg_kd_count, jcp.stride_d); + + /* Final number of kernel elements that overlap with input */ + const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id); + cmp(reg_kd_count, inp_ker_overlap); + jl(common_block_label, T_NEAR); + + /* Correct any excess shifts to kernel and input */ + if (jcp.f_pad <= jcp.od * jcp.stride_d) { + /* Filter has moved beyond padding (adjust for stride effects) */ + if (jcp.f_pad % jcp.stride_d != 0) { + int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d; + add(reg_kernel, filter_shift * inp_corr); + add(reg_input_d, input_shift * inp_corr); + } + } else { + /* Filter still overlaps padding (complete reset) */ + sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift); + } + + /* Apply correction */ + mov(reg_kd_count, jcp.kd); + jmp(common_block_label); + + L(fpad_end_label); + } + + /* Compute bottom edge */ + if (jcp.back_pad > 0) { + + /* Check if within back_pad region */ + cmp(reg_d_index, input_backpad_overlap - 1); + jl(backpad_end_label, T_NEAR); + jg(backpad_label, T_NEAR); + + /* Execute overlap correction between the filter and the initial + * back_pad region. */ + mov(reg_kd_count, + jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d); + jmp(backpad_end_label, T_NEAR); + + L(backpad_label); + sub(reg_kd_count, jcp.stride_d); + cmp(reg_kd_count, 0); + jle(loop_end_label, T_NEAR); + + L(backpad_end_label); + } + + /* Compute middle block */ + add(reg_input_d, input_shift * jcp.stride_d); + + /* Execute common block and loop */ + L(common_block_label); + add(reg_output_d, output_shift); + inc(reg_d_index); + cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); + jl(d_loop_label, T_NEAR); + + L(loop_end_label); +} + +bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() { + // FIXME: use register mapping from the class declaration + bool ok = jcp.ver == ver_4fma + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) + && everyone_is(1, jcp.stride_h, jcp.stride_w); + if (!ok) return false; + if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2) + return false; + + // General code layout: + // + // Blocking over OH -- top level + // (Reduces L2 pressure; not very useful right now) + // Loop over all KHxKW kernel -- emit_kh_kw_loop() + // Loop over OH block -- emit_h_loop() + // Loop over OW blocks -- emit_fma_block() + // (Supports both fully unrolled and partially unrolled versions to + // reduce code size) + // Loop over OW block -- emit_fma_step() + + int max_working_set_size = 128 * 1024; + int pad_ow = jcp.ow; + + int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in; + int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in; + int row_size = inp_row_size + out_row_size; + + int h_block_size = jcp.oh; + int working_set_size = row_size * h_block_size; + + if (working_set_size > max_working_set_size) { + int opt_working_set_size = 48 * 1024; + assert(opt_working_set_size < max_working_set_size); + + while (working_set_size > opt_working_set_size) { + for (int i = 2; i <= h_block_size; i++) + if (i == h_block_size) + h_block_size = h_block_size / 2; + else if (h_block_size % i == 0) { + h_block_size = h_block_size / i; + break; + } + working_set_size = row_size * h_block_size; + + if (h_block_size == 1 && working_set_size > opt_working_set_size) + return false; + } + } + + // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below) + if (h_block_size < nstl::max(1, jcp.t_pad) + || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size + : jcp.oh % h_block_size)) + return false; + + // check that we can use simple arithmetic for prefetch address + // calculations + // TODO: we need some traits for this check (Roma) + int cache_line_size = 64; + assert(jcp.ic_block * typesize == 64); + assert(jcp.oc_block * typesize == 64); + + int num_inp_l2_pfs = jcp.tr_iw * h_block_size; + int avg_h_loop_len = h_block_size; + int num_inp_l2_pfs_per_fma_block + = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); + int num_out_l2_pfs = pad_ow * h_block_size; + int num_out_l2_pfs_per_fma_block + = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); + + Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors + Reg64 reg_kh = rax; + Reg64 reg_kw = rbx; + Reg64 reg_tmp = abi_not_param1; + Reg32 reg_tmp_w = reg_tmp.cvt32(); + Reg64 reg_ohs = rdx; + Reg64 reg_ihs = rsi; + Reg64 reg_h = r8; + Reg64 reg_i = r9; + Reg64 reg_j = r10; + + Reg64 reg_inp = r13; + Reg64 reg_out = r14; + Reg64 reg_ker = r15; + + Reg64 reg_inp_pf_l1 = rbp; + + Reg64 reg_inp_pf_l2 = r11; + Reg64 reg_out_pf_l2 = r12; + + Xmm reg_inp_pf_save = xmm17; + Xmm reg_out_pf_save = xmm18; + + Reg64 reg_inp_save = abi_param1; + Reg64 reg_out_save = reg_tmp; + + auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); }; + auto zmm_ker = [&](int ic1) { return Zmm(ic1); }; + auto inp_addr = [&](int oi, int ic1) { + return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in]; + }; + auto out_addr = [&](int oi, int oj = 0) { + assert(jcp.ver == ver_4fma); + return ptr[reg_out + + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in]; + }; + auto ker_addr = [&](int ic1) { + return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out]; + }; + + auto emit_block = [&](int h_block_size, + bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row) + { + // TODO: add an fma version (Roma) + auto pad_ow = jcp.ow; + + int ow4u = rnd_up(pad_ow, 4); + int def_step_size = 16; + + bool has_w_tail = (pad_ow % def_step_size != 0 + || pad_ow % 4 != 0); + bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail; + + auto emit_step = [&](int ur_ow, + int num_inp_l1_pfs_per_fma_step, + int num_inp_l2_pfs_per_fma_step, + int num_out_l2_pfs_per_fma_step, bool is_w_tail) + { + bool block_wraparound = is_w_tail && is_last_row; + + assert(ur_ow % 4 == 0); + int tail_size = ow4u % ur_ow; + int this_ur_ow + = (is_w_tail && tail_size) ? tail_size : ur_ow; + int ow_last_chunk4 = pad_ow % 4; + int ow_zero_tail4 = ow_last_chunk4 + ? 4 - ow_last_chunk4 : 0; + + auto emit_out_pf = [&](int oi) { +#if 1 + if (oi + def_step_size < ur_ow || !block_wraparound) + mic_prefetcht0(ptr[reg_out + + ((def_step_size + oi) + * jcp.oc_block * jcp.typesize_in)]); + else { + assert(block_wraparound); + assert(oi + def_step_size >= ur_ow); + mic_prefetcht0(ptr[reg_out_save + + ((oi + def_step_size - ur_ow) + * jcp.oc_block * jcp.typesize_in)]); + } +#else + // XXX: This is an alternative prefetching strategy that + // always prefetches the next row. Keeping it here for + // future experiments (Roma) + if (!block_wraparound) + mic_prefetcht0(ptr[reg_out + + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]); + else + mic_prefetcht0(ptr[reg_out + reg_ohs + - ((h_block_size - 1) * jcp.ow + - oi) * jcp.oc_block * jcp.typesize_in]); +#endif + if (oi < num_out_l2_pfs_per_fma_step) + mic_prefetcht1(ptr[reg_out_pf_l2 + + oi * jcp.oc_block * jcp.typesize_in]); + }; + + auto emit_inp_pf = [&](int oi4, int ic1) { + int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block; + int num_pf_slots = jcp.ic_block * ur_ow / 4; + + int num_pfs = num_inp_l1_pfs_per_fma_step + + num_inp_l2_pfs_per_fma_step; + int pf_freq = nstl::max(1, num_pf_slots / num_pfs); + + if (pf_slot_idx % pf_freq) + return; + + int pf_idx = pf_slot_idx / pf_freq; + + if (pf_idx < num_inp_l2_pfs_per_fma_step) + mic_prefetcht1(ptr[reg_inp_pf_l2 + + pf_idx * jcp.ic_block * jcp.typesize_in]); + else { + pf_idx -= num_inp_l2_pfs_per_fma_step; + // prefetch the 'tail' of the cache line because most of + // the accesses are not aligned + mic_prefetcht0(ptr[reg_inp_pf_l1 + + pf_idx * jcp.ic_block * jcp.typesize_in + + cache_line_size - jcp.typesize_in]); + } + }; + + auto numloads = 4; + + int steps = this_ur_ow; + for (int oi4 = 0; oi4 < steps; oi4 += numloads) { + for (int oi1 = 0; oi1 < numloads; oi1++) { + int oi = oi4 + oi1; + if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) { + vmovups(zmm_out(oi), out_addr(oi)); + emit_out_pf(oi); + } else { + auto zmm = zmm_out(oi); + vpxord(zmm, zmm, zmm); + } + } + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + if (jcp.ver == ver_4fma) { + v4fmaddps(zmm_ker(ic1), + zmm_out(oi4), inp_addr(oi4, ic1)); + } else { + assert(!"unknown convolution version"); + } + emit_inp_pf(oi4, ic1); + } + } + }; + + // Input is transposed and padded but we only access about jcp.iw + // elements so use that to compute the # of cache lines in each 'row' + int num_inp_l1_pfs + = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block; + + if (full_w_unroll) { + emit_step(ow4u, num_inp_l1_pfs, + num_inp_l2_pfs_per_fma_block, + num_out_l2_pfs_per_fma_block, true); + add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size); + add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size); + } else { + Label w_loop; + int num_w_iters = pad_ow / def_step_size; + int num_w_iters_full = num_w_iters + has_w_tail; + int num_inp_l1_pfs_per_fma_step + = div_up(num_inp_l1_pfs, num_w_iters_full); + int num_inp_l2_pfs_per_fma_step + = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full); + int num_out_l2_pfs_per_fma_step + = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full); + mov(reg_i, num_w_iters); + L(w_loop); { + emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, + num_inp_l2_pfs_per_fma_step, + num_out_l2_pfs_per_fma_step, false); + add(reg_inp, def_step_size * jcp.typesize_in); + add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in); + add(reg_inp_pf_l1, + num_inp_l1_pfs_per_fma_step * cache_line_size); + add(reg_inp_pf_l2, + num_inp_l2_pfs_per_fma_step * cache_line_size); + add(reg_out_pf_l2, + num_out_l2_pfs_per_fma_step * cache_line_size); + sub(reg_i, 1); + jnz(w_loop); + } + if (has_w_tail) { + emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, + num_inp_l2_pfs_per_fma_step, + num_out_l2_pfs_per_fma_step, true); + add(reg_inp_pf_l2, + num_inp_l2_pfs_per_fma_step * cache_line_size); + add(reg_out_pf_l2, + num_out_l2_pfs_per_fma_step * cache_line_size); + } + // reset reg_inp and reg_out because emit_h_loop expects + // unmodified pointers + int w_offset = num_w_iters * def_step_size; + sub(reg_inp, w_offset * jcp.typesize_in); + sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in); + } + }; + + auto emit_h_loop = [&](int h_block_size, + bool is_last_block, bool is_last_kh_kw_iter) + { + Label h_loop, skip_h_loop; + mov(reg_j, 1); + cmp(reg_j, reg_h); + je(skip_h_loop, T_NEAR); + L(h_loop); { + + lea(reg_inp_pf_l1, + ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]); + emit_block(h_block_size, + is_last_block, is_last_kh_kw_iter, false); + + add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in); + add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in); + add(reg_j, 1); + cmp(reg_j, reg_h); + jb(h_loop); + } + + L(skip_h_loop); + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + mic_prefetcht0(ker_addr(ic1)); + + lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]); + emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true); + }; + + auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block, + int h_block_size) + { + xor_(reg_kh, reg_kh); + Label kh_loop, kh_loop_end; + + int last_oh_block_size + = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size); + int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size; + // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size + int ih_block_size = oh_block_size - 1 + jcp.kh + - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad; + + L(kh_loop); { + // determine starting indices for this block + if (is_first_block) { + xor_(reg_tmp, reg_tmp); + mov(reg_ohs, jcp.t_pad); + sub(reg_ohs, reg_kh); + cmovb(reg_ohs, reg_tmp); + + mov(reg_ihs, reg_ohs); + sub(reg_ihs, jcp.t_pad); + add(reg_ihs, reg_kh); + } else { + xor_(reg_ohs, reg_ohs); + mov(reg_ihs, reg_kh); + } + + // determine effective size of block based on padding + mov(reg_tmp, oh_block_size); + sub(reg_tmp, reg_ohs); + mov(reg_h, ih_block_size); + sub(reg_h, reg_ihs); + cmp(reg_tmp, reg_h); + cmovb(reg_h, reg_tmp); + + Label kh_loop_work; + cmp(reg_h, 0); + jg(kh_loop_work, T_NEAR); + + // empty h loop for this jcp.kh: + // - set the output to 0 if necessary + // - move ker pt + // - jump to the end + sub(reg_h, 1); + Label skip_ker_zeroing; + + // The reg_ker ptr has highest bit set if the output needs to be + // zeroed. Those who have byte-aligned their data will suffer the + // consiquences :( + // TODO: move the flag to a mask register? (Roma) + test(reg_ker, 1); + jz(skip_ker_zeroing, T_NEAR); + + Label zeroing_loop; + vpxord(zmm0, zmm0, zmm0); + and_(reg_ker, ~1); // temporarily clear the zeroing flag + mov(reg_tmp, jcp.kw); + L(zeroing_loop); { + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + vmovups(ker_addr(ic1), zmm0); + add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out); + sub(reg_tmp, 1); + jnz(zeroing_loop, T_NEAR); + } + // restore the zeroing flag (it will be cleared after the end of + // emit_kh_kw_loop, but we may need it until then) + or_(reg_ker, 1); + jmp(kh_loop_end, T_NEAR); + + L(skip_ker_zeroing); + add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw + * jcp.typesize_out); + jmp(kh_loop_end, T_NEAR); + + L(kh_loop_work); + + mul_by_const(reg_ihs, reg_tmp, + jcp.tr_iw * jcp.ic_block * jcp.typesize_in); + mul_by_const(reg_ohs, reg_tmp, + pad_ow * jcp.oc_block * jcp.typesize_in); + + add(reg_inp, reg_ihs); + add(reg_out, reg_ohs); + + Label kw_loop; + xor_(reg_kw, reg_kw); + L(kw_loop); { + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + vpxord(zmm, zmm, zmm); + mic_prefetcht1(ker_addr(ic1)); + } + + mov(reg_out_save, reg_out); + mov(reg_inp_save, reg_inp); + lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]); + +#if 0 + // XXX: Generate code with special prefetches when switching + // blocks or at the end of the last block. Disabled to reduce + // code size and because there's no performance benefit (Roma) + Label regular_h_loop, end_h_loop; + cmp(reg_kw, jcp.kw - 1); + jne(regular_h_loop, T_NEAR); + cmp(reg_kh, jcp.kh - 1); + jne(regular_h_loop, T_NEAR); + + emit_h_loop(oh_block_size, is_last_block, true); + jmp(end_h_loop, T_NEAR); + + L(regular_h_loop); + emit_h_loop(oh_block_size, is_last_block, false); + + L(end_h_loop); +#else + emit_h_loop(oh_block_size, is_last_block, false); +#endif + + mov(reg_out, reg_out_save); + mov(reg_inp, reg_inp_save); + + Label do_store; + // The reg_ker ptr has highest bit set if the output needs to + // be zeroed. Those who have byte-aligned their data will + // suffer the consiquences :( + mov(reg_tmp, reg_ker); + and_(reg_ker, ~1); + test(reg_tmp, 1); + jnz(do_store, T_NEAR); + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + if (jcp.ver == ver_4fma) { + vaddps(zmm, ker_addr(ic1)); + } else { + assert(!"unknown convolution version"); + } + } + + L(do_store); + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + vmovups(ker_addr(ic1), zmm); + } + + mov(reg_ker, reg_tmp); + add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out); + add(reg_kw, 1); + cmp(reg_kw, jcp.kw); + jl(kw_loop); + } + + sub(reg_inp, reg_ihs); + sub(reg_out, reg_ohs); + + + L(kh_loop_end); + add(reg_kh, 1); + cmp(reg_kh, jcp.kh); + jl(kh_loop); + } + }; + + mov(reg_inp, ptr[param + GET_OFF(src)]); + mov(reg_out, ptr[param + GET_OFF(dst)]); + mov(reg_ker, ptr[param + GET_OFF(filt)]); + mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]); + mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]); + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + or_(reg_ker, reg_tmp); + + bool single_kh_kw_loop = (h_block_size == jcp.oh); + + size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in; + size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad); + size_t inp_block_step = inp_row_step * h_block_size; + size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in + * h_block_size; + + if (!single_kh_kw_loop) { + // Save the original prefetch pointers from the OpenMP driver + vmovq(reg_inp_pf_save, reg_inp_pf_l2); + vmovq(reg_out_pf_save, reg_out_pf_l2); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, first_inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + } + emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size); + + if (!single_kh_kw_loop) { + size_t ker_reset_offset + = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh; + sub(reg_ker, ker_reset_offset); + and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates + + add(reg_inp, first_inp_block_step); + add(reg_out, out_block_step); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + + int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2; + if (num_innermost_iters > 0) { + Label h_block_loop; + + mov(reg_tmp_w, num_innermost_iters); + kmovw(reg_h_block, reg_tmp_w); + L(h_block_loop); { + emit_kh_kw_loop(false, false, h_block_size); + sub(reg_ker, ker_reset_offset); + add(reg_inp, inp_row_step * h_block_size); + add(reg_out, out_block_step); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + kmovw(reg_tmp_w, reg_h_block); + sub(reg_tmp_w, 1); + kmovw(reg_h_block, reg_tmp_w); + jnz(h_block_loop); + } + } + + // Restore the original prefetch pointers that came from the OpenMP + // driver + vmovq(reg_inp_pf_l2, reg_inp_pf_save); + vmovq(reg_out_pf_l2, reg_out_pf_save); + emit_kh_kw_loop(false, true, h_block_size); + } + + return true; +} + +bool jit_avx512_common_conv_bwd_weights_kernel_f32 + ::flat_4ops_compute() { + const auto &j = jcp; + const bool ok = j.ver == ver_4fma && j.is_1stconv + && everyone_is(0, j.dilate_h, j.dilate_w); + if (!ok) return false; + + Reg64 reg_ptr_tr_src = r8; + Reg64 reg_ptr_dst = r9; + Reg64 reg_ptr_wei = r10; + Reg64 reg_ptr_bia = r11; + + Reg64 reg_kh_step = rax; + Reg64 reg_oh = abi_not_param1; + Reg64 reg_kh = rdx; + + Reg32 reg_flag_save = ebx; + Reg32 reg_flag = esi; + + Zmm vbia(31); + + auto zmm_wei = [&](int kh, int kw) { + return Zmm(8 + kh * j.kw + kw); + }; + auto zmm_dst = [&](int ow) { + return Zmm(ow % 8); + }; + + auto addr_tr_src = [&](int kh, int iw) { + return ptr[reg_ptr_tr_src + + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in]; + }; + auto addr_dst = [&](int ow) { + return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in]; + }; + auto addr_wei = [&](int kh, int kw) { + return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block + * jcp.typesize_out]; + }; + + auto emit_fma_block = [&](int kh_step) { + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) { + auto vwei = zmm_wei(kh, kw); + vpxord(vwei, vwei, vwei); + } + } + + for (int ow = 0; ow < j.ow; ow += 4) { + for (int _ow = ow; _ow < ow + 4; ++_ow) { + auto vdst = zmm_dst(_ow); + if (_ow < j.ow) + vmovups(vdst, addr_dst(_ow)); + else + vpxord(vdst, vdst, vdst); + } + + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) { + const int iw = ow + (kw % j.stride_w) * j.tr_ld + + (kw / j.stride_w); + v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow), + addr_tr_src(kh, iw)); + if (1 && kh == 0 && kw < 4) { + prefetcht1(ptr[reg_ptr_dst + + (j.ow + ow + kw) * jcp.oc_block + * jcp.typesize_in]); + } + if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */ + const int off = kw + 4 - j.kw; + if (off >= 0 && ow + off < j.ow) + vaddps(vbia, vbia, zmm_dst(ow + off)); + } + } + } + } + + Label l_store; + test(reg_flag, FLAG_MB_FIRST); + jnz(l_store, T_NEAR); + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) + vaddps(zmm_wei(kh, kw), addr_wei(kh, kw)); + } + L(l_store); + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) + vmovups(addr_wei(kh, kw), zmm_wei(kh, kw)); + } + }; + + auto emit_kh_loop = [&]() { + const int kh_step_rem = j.kh % j.kh_step; + xor_(reg_kh, reg_kh); + mov(reg_kh_step, j.kh_step); + + Label l_kh_loop; + L(l_kh_loop); { + Label l_done; + + if (kh_step_rem != 0) { + Label l_keep_kh_step; + cmp(reg_kh, j.kh - j.kh_step); + jle(l_keep_kh_step, T_NEAR); + + mov(reg_kh_step, kh_step_rem); + emit_fma_block(kh_step_rem); + jmp(l_done, T_NEAR); + + L(l_keep_kh_step); + } + + emit_fma_block(j.kh_step); + + L(l_done); + + add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld + * jcp.typesize_in); + add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out); + add(reg_kh, j.kh_step); + + cmp(reg_kh, j.kh); + jl(l_kh_loop, T_NEAR); + } + + const int kh_steps = rnd_up(j.kh, j.kh_step); + sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in); + sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out); + }; + + auto emit_oh_loop = [&]() { + mov(reg_oh, j.oh); + + Label l_oh_loop; + L(l_oh_loop); { + Label l_restore_mb_flag, l_jump; + + cmp(reg_oh, j.oh); + je(l_restore_mb_flag, T_NEAR); + + and_(reg_flag, ~FLAG_MB_FIRST); + jmp(l_jump, T_NEAR); + + L(l_restore_mb_flag); + mov(reg_flag, reg_flag_save); + + L(l_jump); + + emit_kh_loop(); + + add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld + * jcp.typesize_in); + add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in); + + dec(reg_oh); + jnz(l_oh_loop, T_NEAR); + } + }; + + auto emit_bia_store = [&]() { + if (!j.with_bias) return; + + Label l_bia_store, l_bia_skip; + test(reg_flag, FLAG_IC_FIRST); + jz(l_bia_skip); + + test(reg_flag, FLAG_MB_FIRST); + jnz(l_bia_store, T_NEAR); + vaddps(vbia, ptr[reg_ptr_bia]); + L(l_bia_store); + vmovups(ptr[reg_ptr_bia], vbia); + L(l_bia_skip); + }; + + mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]); + mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]); + mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]); + mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]); + mov(reg_flag_save, ptr[param + GET_OFF(flags)]); + + vpxord(vbia, vbia, vbia); + emit_oh_loop(); + emit_bia_store(); + + return true; +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop() +{ + if (flat_4ops_compute()) + return; + if (compute_full_spat_loop()) + return; + + maybe_zero_kernel(); + + if (jcp.ndims == 5) compute_d_loop_common(); + else compute_oh_loop_common(); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::generate() +{ + preamble(); + + mov(reg_input, ptr[param + GET_OFF(src)]); + mov(reg_output, ptr[param + GET_OFF(dst)]); + mov(reg_kernel, ptr[param + GET_OFF(filt)]); + + compute_loop(); + + postamble(); +} + +status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &diff_weights_md, + memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) { + if (!mayiuse(avx512_common)) + return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper diff_weights_d(&diff_weights_md); + const memory_desc_wrapper diff_bias_d(&diff_bias_md); + const memory_desc_wrapper diff_dst_d(&diff_dst_md); + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + + jcp = zero(); + + jcp.simd_w = cpu_isa_traits::vlen / sizeof(float); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; + jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1); + bool ok = true + // general condition to simplify dilations + && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1) + && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1) + && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1) + // special condition to simplify dilations in compute_oh_loop_common + && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih); + if (!ok) + return status::unimplemented; + + jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1)); + jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1)); + + /* XXX: currently, does not support dilation_d > 0 */ + if (ndims == 5) + if (jcp.dilate_d > 0) + return status::unimplemented; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.aligned_threads = 0; + + /* check for the 1st convolution */ + jcp.is_1stconv = is_1stconv(jcp); + + jcp.oc_block = jcp.simd_w; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + + if (ok_to_pad_channels) + jcp.oc = rnd_up(jcp.oc, jcp.simd_w); + + if (jcp.oc % jcp.oc_block) + return status::unimplemented; + + auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o); + + if (diff_dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag)); + jcp.dst_tag = dst_tag; + } else { + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag); + } + if (jcp.dst_tag != dst_tag) + return status::unimplemented; + + /* conditions on bias memory */ + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + if (jcp.with_bias) { + if (diff_bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md, x)); + } + + jcp.nb_oc = jcp.oc / jcp.oc_block; + + /* kernel applicability check wrt boundaries + * the conditions are quite general across the kernels we have, + * but ideally the check should belong to a specific kernel... */ + const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2; + const bool boundaries_ok = true + && jcp.t_pad <= max_pad + && jcp.b_pad <= max_pad + && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad) + && jcp.f_pad < jcp.kd; + if (!boundaries_ok) + return status::unimplemented; + + /* yet another common check */ + if (jcp.kw > 14) + return status::unimplemented; + + /* setting register strategy */ + for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) { + if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; } + } + + if (jcp.is_1stconv) { + auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw); + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + if (jcp.ic == 1 && jcp.src_tag != src_tag) + jcp.src_tag = src_d.matches_one_of_tag( + pick(ndims - 3, nwc, nhwc, ndhwc)); + } + if (jcp.src_tag == format_tag::undef) + return status::unimplemented; + + const bool src_ok = true + && utils::everyone_is(data_type::f32, + src_d.data_type(), diff_weights_d.data_type(), + diff_dst_d.data_type()) + && one_of(jcp.ic, 1, 2, 3) + && jcp.ngroups == 1; + if (!src_ok) + return status::unimplemented; + + const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad, + jcp.stride_w), 16); + const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1); + const int kh_step_rem = jcp.kh % kh_step; + + const auto wei_4fma_tag = with_groups + ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o) + : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o); + + auto current_wei_tag = format_tag::undef; + if (diff_weights_d.format_kind() != format_kind::any) + current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag); + + const bool use_4fma = true + && one_of(ndims, 3, 4) + && mayiuse(avx512_mic_4ops) + && mkldnn_thr_syncable() + && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) + && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad) + && jcp.kw <= 28 - jcp.with_bias + && jcp.stride_w == 4 + && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */ + && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */ + && IMPLICATION(diff_weights_d.format_kind() != format_kind::any, + current_wei_tag == wei_4fma_tag); + + if (use_4fma) { + jcp.ver = ver_4fma; + jcp.kh_step = kh_step; + jcp.tr_ld = tr_ld; + jcp.ic_block = 1; + if (diff_weights_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag)); + jcp.wei_tag = wei_4fma_tag; + } else { + jcp.ver = ver_fma; + jcp.ic_block = jcp.ic; + + wei_tag = with_groups + ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o); + + if (diff_weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + } + + jcp.nb_ic = jcp.ic / jcp.ic_block; + } else { + auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + } + if (jcp.src_tag != src_tag) + return status::unimplemented; + + if (diff_weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + jcp.ic_block = jcp.simd_w; + if (ok_to_pad_channels) + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + jcp.nb_ic = jcp.ic / jcp.ic_block; + if ((mayiuse(avx512_mic) || mayiuse(avx512_core)) + && utils::everyone_is(data_type::f32, + src_d.data_type(), diff_weights_d.data_type(), + diff_dst_d.data_type())) { + jcp.ver = ver_fma; + if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 && + everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && + mkldnn_thr_syncable()) { + jcp.ver = ver_4fma; + } + } else { + return status::unimplemented; + } + if (jcp.ver == ver_4fma) { + jcp.ur_w = jcp.ow; + // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to + // cross the right boundary. The only requirement is not to have + // NaNs there because another multiplicand is always guaranteed to + // be zero. This also may require the top-level driver to allocate + // four extra guarding elements at the very end of the buffer. + // I'm not proud of this hack, but it improves performance by + // about 5-10% depending on the dimensions (Roma) + + const int tr_round = 4; + + jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round); + jcp.tr_src_num_guard_elems = tr_round; // upper bound + } + } + + if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) { + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + } else + return status::unimplemented; + + bool args_ok = true + && jcp.ic % jcp.ic_block == 0 + && jcp.oc % jcp.oc_block == 0 + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + { // balancing + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b); + jcp.nthr = nthr; + jcp.nthr_mb = nthr_mb; + jcp.nthr_g = nthr_g; + jcp.nthr_oc_b = nthr_oc_b; + jcp.nthr_ic_b = nthr_ic_b; + } + + return status::success; +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.ver == ver_4fma) { + if (jcp.is_1stconv) { + const size_t tr_src_size = + jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld; + scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); + } else { + // XXX: See the comment about tr_iw and guarding elements in + // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() + const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic; + const size_t min_tr_src_size_per_thr + = jcp.ih * jcp.ic_block * jcp.tr_iw; + const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr + + jcp.tr_src_num_guard_elems; + scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); + } + + /* prepare synchronization contexts */ + if (jcp.nthr_oc_b > 1) { + const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; + scratchpad.book(key_conv_tr_src_bctx, + sizeof(simple_barrier::ctx_t) * tr_src_bctx_size); + } + } + + if (jcp.nthr_mb > 1) { + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic + * jcp.kh * jcp.kw * jcp.kd; + const int bia_size = jcp.ngroups * jcp.oc; + const size_t wei_bia_reduction_size = wei_size + bia_size; + + scratchpad.book(key_conv_wei_bia_reduction, + jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1)); + scratchpad.book(key_conv_wei_bia_reduction_bctx, + sizeof(simple_barrier::ctx_t)); + } + + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::balance( + const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_, + int &nthr_oc_b_, int &nthr_ic_b_) +{ + nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; + + const int max_threads = mkldnn_get_max_threads(); + + if (max_threads < j.ngroups) { + /* simplification... fortunately it doesn't hurt much */ + return; + } + + if (!mkldnn_thr_syncable() && j.ver == ver_4fma) { + // should not happen -- the driver is not ready + // for TBB-like non-synchronous threading yet + return; + } + + if (j.ver == ver_4fma && j.is_1stconv) { + nthr_g_ = 1; + nthr_oc_b_ = 1; + nthr_ic_b_ = nstl::min(j.nb_ic, max_threads); + nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb); + nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_; + return; + } + + nthr_g_ = j.ngroups; + const int nthr = max_threads / nthr_g_; + + auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + /* calculate per thread memory cost (read/write). high level optimizer + * tries to minimize memory consumption. few notes: + * (n1) unclear why, but that essentially helps first convolution... + * (n2) assuming the reduction over minibatch is always there: + * - instead of 8 it should be 5 here (write ~= 2 read): + * kernel: temporal workspace 1 write + * reduction: 1 read from workspace and 1 write to the diff_wei + * - but experiments showed 8 works better than 5 or 6... */ + + const int src_coef = j.ver == ver_4fma ? 4 : 1; + const int dst_coef = 1; + const int wei_coef = 8; + + return 0 + + src_coef + * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id + / j.stride_d / j.stride_h / j.stride_w /* (n1) */ + + dst_coef + * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od + + wei_coef /* (n2) */ + * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) + * j.kh * j.kw * j.kd * j.ic_block * j.oc_block; + }; + + int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); + + /* step 1: find the best thread distribution with lowest memory cost */ + const int nthr_mb_max = nstl::min(nthr, j.mb * j.od); + for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); + for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); + + int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + if (mem_cost <= best_mem_cost) { + best_mem_cost = mem_cost; + nthr_mb_ = nthr_mb; + nthr_oc_b_ = nthr_oc_b; + nthr_ic_b_ = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + + if (!mayiuse(avx512_mic)) { + auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + return 1 + * div_up(j.mb, nthr_mb) + * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) + * div_up(j.nb_ic, nthr_ic_b); + }; + + /* step 2: search for a thread distribution with lower compute cost. + * the constrains: + * - memory cost cannot exceed 110% of the best found in the step 1 + * - unless compute cost is 133% lower than the current best case + * note: both constants were found empirically */ + int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); + for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); + for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); + int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + + const bool opt1 = comp_cost <= best_comp_cost + && mem_cost < 1.1 * best_mem_cost; + const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost; + + if (opt1 || opt2) { + best_comp_cost = comp_cost; + nthr_mb_ = nthr_mb; + nthr_oc_b_ = nthr_oc_b; + nthr_ic_b_ = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + } + + if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads) + nthr_mb_ = nstl::min(j.mb * j.od, max_threads); + nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; + + assert(nthr_ <= max_threads); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1)); +} + +template struct _jit_avx512_common_conv_fwd_kernel; +template struct _jit_avx512_common_conv_fwd_kernel; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp new file mode 100644 index 0000000000..f76770797a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp @@ -0,0 +1,423 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP +#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { + + _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + generate(); + jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); + } + + ~_jit_avx512_common_conv_fwd_kernel() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel) + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker_)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + }; + + reg64_t param = abi_param1; + reg64_t reg_inp = r8; + reg64_t reg_ker = r9; + reg64_t reg_out = r10; + + reg64_t reg_inp_prf = r11; + reg64_t reg_ker_prf = r12; + reg64_t reg_out_prf = r13; + reg64_t reg_owb = r12; + + reg64_t aux_reg_inp = r14; + reg64_t aux_reg_ker = r15; + + reg64_t aux_reg_inp_prf = rsi; + reg64_t aux_reg_ker_prf = rdx; + + reg64_t reg_channel = rsi; + reg64_t reg_bias = rdx; + + reg64_t aux_reg_ker_d = r9; + reg64_t aux_reg_inp_d = rbx; + reg64_t aux_reg_inp_d_prf = r13; + reg64_t aux_reg_ker_d_prf = abi_not_param1; + reg64_t reg_ki = r10; + + reg64_t reg_kj = rax; + reg64_t reg_relu_ns = rax; + reg64_t reg_oi = rbx; + reg64_t reg_kh = abi_not_param1; + + reg64_t reg_tmp = rbp; + + reg64_t reg_ic_loop = rdx; + reg64_t reg_inp_loop = rsi; + + reg64_t reg_init_flag = r13; + reg64_t reg_bias_ptr = param; + + reg64_t aux_reg_ic = r12; + reg64_t reg_binp = rax; + reg64_t reg_bout = r11; + reg64_t aux1_reg_inp = rbx; + reg64_t aux_reg_out = abi_not_param1; + + reg64_t reg_long_offt = r11; + reg64_t reg_out_long_offt = r14; + + inline Vmm vmm_ker(int i_ic) { + assert(i_ic < 4); + return Vmm(ker_reg_base_idx + i_ic); + } + + inline Vmm vmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < ker_reg_base_idx); + return Vmm(idx); + } + + inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Vmm(idx); + } + + Xbyak::Reg64 imm_addr64 = r15; + Vmm vmm_wei = Vmm(31); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void prepare_output(int ur_w); + inline void store_output(int ur_w); + inline void compute_loop_fma(int ur_w, int pad_l, int pad_r); + inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); + inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r); + inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r); + inline void compute_loop(int ur_w, int pad_l, int pad_r); + + void generate(); + + inline size_t get_output_offset(int oi, int n_oc_block) { + return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh + * jcp.ow * jcp.od + oi) * jcp.oc_block; + } + + inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { + size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1; + size_t ic_str = !jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id; + return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1) + + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str); + } + + inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) { + return jcp.typesize_in * jcp.oc_block + * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd + + (ic + ker_number) + ki * jcp.ic_block); + } + + inline int get_ow_start(int ki, int pad_l) { + return nstl::max(0, + utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); + } + + inline int get_ow_end(int ur_w, int ki, int pad_r) { + return ur_w - nstl::max(0, utils::div_up(pad_r + - (jcp.kw - 1 - ki) + * (jcp.dilate_w + 1), + jcp.stride_w)); + } +}; + +struct jit_avx512_common_conv_fwd_kernel { + + jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : + jit_ker(nullptr), + zmm_kernel_(nullptr), + xmm_kernel_(nullptr) { + int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block; + switch (ch_block) { + case 16: + zmm_kernel_ = + new _jit_avx512_common_conv_fwd_kernel( + ajcp, attr); + jit_ker = zmm_kernel_->jit_ker_; + return; + case 4: + xmm_kernel_ = + new _jit_avx512_common_conv_fwd_kernel( + ajcp, attr); + jit_ker = xmm_kernel_->jit_ker_; + return; + default: + assert(!"invalid channel blocking"); + } + } + + ~jit_avx512_common_conv_fwd_kernel() { + delete xmm_kernel_; + delete zmm_kernel_; + } + + enum { + typesize = sizeof(float) + }; + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_pd, + memory_desc_t &weights_pd, + memory_desc_t &dst_pd, + memory_desc_t &bias_pd, + const primitive_attr_t &attr, + int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + void(*jit_ker)(jit_conv_call_s *); + _jit_avx512_common_conv_fwd_kernel *zmm_kernel_; + _jit_avx512_common_conv_fwd_kernel *xmm_kernel_; +}; + +struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator { + + jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + generate(); + jit_ker = (void (*)(jit_conv_call_s *))getCode(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32) + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + }; + + reg64_t param = abi_param1; + reg64_t reg_dst = r8; + reg64_t reg_ker = r9; + reg64_t reg_src = r10; + + reg64_t reg_dst_prf = r11; + reg64_t reg_ker_prf = r12; + reg64_t reg_src_prf = r13; + + reg64_t aux_reg_dst = r14; + reg64_t aux_reg_ker = r15; + + reg64_t aux_reg_dst_prf = rsi; + reg64_t aux_reg_ker_prf = rdx; + + reg64_t aux_reg_dst_d_prf = r13; + reg64_t aux_reg_dst_d = rbx; + reg64_t aux_reg_ker_d_prf = abi_not_param1; + reg64_t aux_reg_ker_d = r9; + reg64_t reg_ki = r10; + + reg64_t reg_kj = rax; + reg64_t reg_oi = rbx; + reg64_t reg_kh = abi_not_param1; + + reg64_t reg_channel = rsi; + + reg64_t reg_tmp = rbp; + reg64_t reg_long_offt = r14; + + inline Xbyak::Zmm zmm_ker(int i_ic) { + assert(i_ic < 4); + return Xbyak::Zmm(ker_reg_base_idx + i_ic); + } + inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Xbyak::Zmm(idx); + } + inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < ker_reg_base_idx); + return Xbyak::Zmm(idx); + } + + Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); + + inline void prepare_output(int ur_w); + inline void store_output(int ur_w); + inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); + void generate(); + + inline int get_iw_start(int ki, int l_overflow) + { + int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return res; + } + + inline int get_iw_end(int ur_w, int ki, int r_overflow) + { + if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return ur_w - res; + } +}; + +struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator { + + jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) + : jcp(ajcp) + { + generate(); + jit_ker = (void (*)(jit_conv_call_s *))getCode(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32) + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_md, + memory_desc_t &diff_weights_md, + memory_desc_t &diff_bias_md, + memory_desc_t &diff_dst_md); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum {typesize = sizeof(float)}; + static const int max_ur_w; + + reg64_t param = abi_param1; + reg64_t reg_input = rax; + reg64_t reg_kernel = rdx; + reg64_t reg_output = rsi; + reg64_t b_ic = abi_not_param1; + reg64_t kj = r8; + reg64_t reg_kh = r9; + reg64_t reg_ur_w_trips = r10; + reg64_t reg_oj = r15; + reg64_t reg_ih_count = rbx; + reg64_t reg_tmp = r14; + reg64_t reg_long_offt = r14; + + reg64_t ki = r11; + reg64_t reg_kd_count = r12; + reg64_t reg_oi = r12; + reg64_t reg_d_index = r13; + reg64_t reg_input_d = r15; + reg64_t reg_output_d = rbx; + reg64_t aux_reg_input = r12; + reg64_t aux_reg_kernel = r13; + reg64_t reg_bias = rbx; + + inline void bias_kernel(); + inline void maybe_zero_kernel(); + inline void compute_oh_step_unroll_ow_icblock(int ic_block_step, + int max_ur_w); + inline void od_step_comeback_pointers(); + inline void oh_step_comeback_pointers(); + inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); + inline void compute_ic_block_step(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound = false); + inline void compute_ic_block_step_fma(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound); + inline void compute_ic_block_step_4fma(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound); + inline void compute_oh_step_common(int ic_block_step, int max_ur_w); + inline void compute_oh_step_disp(); + inline void compute_oh_loop_common(); + inline void compute_d_loop_common(); + + inline bool compute_full_spat_loop(); + inline bool flat_4ops_compute(); + + inline void compute_loop(); + + void generate(); + + static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, + int &nthr_g, int &nthr_oc_b, int &nthr_ic_b); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp new file mode 100644 index 0000000000..1bdcd0d6a8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp @@ -0,0 +1,1163 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +#ifndef KERNEL_SIZE_THRESHOLD +#define KERNEL_SIZE_THRESHOLD 16 +#endif + +#define MIN_REQUIRED_DIMN_REG_BLOCK 14 + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace mkldnn::impl::utils; + +unsigned int L1_cache_size = get_cache_size(1, true); +unsigned int L2_cache_size = get_cache_size(2, true); +unsigned int LLC_data_size = get_cache_size(3, false); + +// the test funtion takes jcp, the candidate and the current best. +// it returns true if the new candidate is better +int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, + int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) +{ + int best_divisor = default_best; + auto test_num + = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { + if (test(jcp, num, best_divisor)) { + best_divisor = num; + } + }; + + for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { + if (number % divisor == 0) { + test_num(jcp, divisor); + test_num(jcp, number / divisor); + } + } + + return best_divisor; +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { + if (jcp.ver == ver_4fma) + return jcp.mb >= 32; + else + return jcp.mb >= 16; +} +} + +/* assumes 512 bits registers */ +/* TODO: add support for strides */ +/* TODO: handle the prefetch distance automatically */ +typedef enum cache_t_ { L1, L2, L3 } cache_t; + +template +struct prefetcher_t { + prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, + cache_t cache_type, size_t block_size, /* in number of elements*/ + int nb_instructions_in_block, int fma_ipc) + : cg_(generator) + , reg_base_addr_(reg_base_addr) + , cache_type_(cache_type) + , cache_block_size_(block_size) + { + nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); + prefetch_spread_ + = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); + prefetch_blk_ + = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); + + /* assumption: when fetch in Li, data is already in L(i+1) */ + int cache_latency; + switch (cache_type_) { + case L1: cache_latency = 14; break; + case L2: + case L3: + default: cache_latency = 250; break; + } + + prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); + } + + void prefetch(int instruction_number) + { + if (instruction_number % prefetch_spread_ == 0) { + for (int i = 0; (i < prefetch_blk_) + && (prefetches_issued_ < nb_cache_lines_to_prefetch_); + i++, prefetches_issued_++) { + prefetch_inst_(cg_->EVEX_compress_addr( + reg_base_addr_, (cache_block_size_ * prefetch_distance_) + * sizeof(data_t) + + (prefetches_issued_ * 64))); + } + } + } + +private: + void prefetch_inst_(const Xbyak::Address &addr) + { + switch (cache_type_) { + case L1: cg_->prefetcht0(addr); break; + case L2: cg_->prefetcht1(addr); break; + case L3: cg_->prefetcht2(addr); break; + default: + break; // TODO: raise an exception or put an assert + } + } + + jit_generator *cg_; + Xbyak::Reg64 reg_base_addr_; + cache_t cache_type_; + int cache_block_size_ = 0; + int nb_cache_lines_to_prefetch_ = 0; + int prefetches_issued_ = 0; + int prefetch_spread_ = 0; + int prefetch_blk_ = 0; + int prefetch_distance_ = 0; +}; + +// utilities to support kernel parameter selection +bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimN_reg_block * dimM_simd_block + + dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} + +bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} + +bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, + int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block, + float C) +{ + float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block + + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block + + nb_dimN_reg_block * dimK_nb_block * dimK_block + * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L2_cache_size; + return (lhs < rhs); +} +} + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate( + bool is_beta_zero) +{ + // const int dimK_simd_block = jcp.dimK_reg_block; + + // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) + // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) + // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; + // dimK_reg_block++) + // for (int tile =0; tile < jcp.dimN_reg_block; tile++) + // C[dimM_block][tile] += + // A[dimM_block][dimK_block][dimK_reg_block] * + // broadcast(B[dimK_block][tile][dimK_reg_block]); + // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block], + // so we load it before the loop on tile + // 2) the loop on tile must be fully unrolled. Don't know about the one on + // dimK_reg_block. I think it should be + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop; + const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1; + const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; + + prefetcher_t L1_pf(this, reg_srcB, L1, + jcp.dimN_reg_block * jcp.dimK_reg_block, + jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, + fma_ipc); + prefetcher_t L2_pf(this, reg_srcB, L2, + jcp.dimN_reg_block * jcp.dimK_reg_block, + jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, + fma_ipc); + + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + { + // First, we zero the accumulators if first nb_ic iteration, + // otherwise we load them + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (is_beta_zero) + vpxord(zmm, zmm, zmm); + else + vmovups(zmm, zword[reg_dstC + 64 * tile]); + } + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + { + auto load_A = [=](int reg_idx, int offset) { + for (int i = 0; i < inc_dimK_reg_block; i++) + vmovups(Zmm(reg_idx + i), + zword[reg_srcA + 64 * (offset + i)]); + }; + + // Used when doing double buffering + int next = 0; + if (jcp.double_buffering) { + load_A(next, 0); + } + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block += inc_dimK_reg_block) { + int current; + /* Loading the next vector from A */ + current = next; + if (jcp.double_buffering) { + next = (dimK_reg_block + inc_dimK_reg_block) + % (2 * inc_dimK_reg_block); + load_A(next, dimK_reg_block + inc_dimK_reg_block); + } else { + next = 0; + load_A(next, dimK_reg_block); + } + /* Performing the fmas */ + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (jcp.ver != ver_avx512_core) + L1_pf.prefetch( + dimK_reg_block * jcp.dimN_reg_block + tile); + if (jcp.ver == ver_4fma) + v4fmaddps(zmm, Zmm(current), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4)); + else + vfmadd231ps(zmm, Zmm(current), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4, + true)); + if (jcp.ver != ver_avx512_core) + L2_pf.prefetch( + dimK_reg_block * jcp.dimN_reg_block + tile); + } + } + + add(reg_srcA, jcp.dimK_reg_block * 64); + add(reg_srcB, jcp.dimN_reg_block * 64); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + } + + + auto store_output = [=](bool output_is_aligned) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (output_is_aligned + && jcp.dimK_nb_block == 1 + && (jcp.dimN * jcp.dimM * alpha * alpha + * sizeof(float) > 2 * LLC_data_size)) + vmovntps(zword[reg_dstC + 64 * tile], zmm); + else + vmovups(zword[reg_dstC + 64 * tile], zmm); + } + }; + + Label unaligned_store, end_store; + test(reg_dstC, cpu_isa_traits::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); + add(reg_dstC, jcp.dimN_reg_block * 64); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + } + }; + + /* Preamble */ + preamble(); + + /* kernel */ + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) +{ + + if (mayiuse(avx512_core)) + return status::unimplemented; + else if (!mayiuse(avx512_common)) + return status::unimplemented; + else if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + else + jcp.ver = ver_fma; + + jcp.nthr = mkldnn_get_max_threads(); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + // Checking conditions not supported by these kernels + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + return status::success; +} + + +status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) { + + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK) + && (dimN_reg_block < jcp.nb_reg) + && (dimN_reg_block < current_best); + }; + jcp.dimN_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block); + + if (jcp.dimN_reg_block >= jcp.nb_reg) { + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return (dimN_reg_block < jcp.nb_reg) + && (dimN_reg_block > current_best); + }; + + jcp.dimN_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimN, 1, test_cond_dimN_reg_block); + } + + //********************* Choosing dimK_block **********************// + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, + 1, jcp.dimM_simd_block, .75f) + && (dimK_block > current_best); + }; + + auto test_cond1_bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, dimK_block, + jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); + // If we are not able to use streams, we fall back to condition [1] + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); + jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; + + //********************* Choosing dimM_block **********************// + jcp.dimM_simd_block = 16; + /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/ + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f) + && (dimM_block > current_best); + }; + + auto test_cond1_bis_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f) + && (dimM_block > current_best); + }; + + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); + else + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block); + jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; + + //******************* Choosing dimN_block *******************// + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, + jcp.dimM_simd_block, .5f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); + jcp.sched_policy = WSCHED_DATA_W_S_G_D; + return status::success; +} + +status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) +{ + jcp.dimK_reg_block = 16; + jcp.dimM_simd_block = 16; + + // TODO: replace double buffering with nuple buffering to maximize register + // usage. + // the choice of the number of buffers will then come after choosing + // dimN_reg_block + jcp.double_buffering = true; + if (jcp.double_buffering) + jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2); + else + jcp.zmm_start = 1; + jcp.nb_reg = 32 - jcp.zmm_start; + + jcp.dimN = dimN; + jcp.dimK = dimK; + jcp.dimM = dimM; + + jcp.sched_policy = WSCHED_INVALID; + set_wsched_DATA_W_S_G_D_avx512_common(jcp); + + assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D); + return status::success; +} + +bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_relu(0) || is_sum(0); // relu or sum + case 2: return (is_sum(0) && is_relu(1)) || + (is_relu(0) && is_sum(1)); // sum->relu or relu->sum + case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu + default: return false; + } + + return false; +} + +status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) { + status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d); + + if (st != status::success) + return st; + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; + jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; + + status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); + jcp.ic_simd_block = jcp.dimK_reg_block; + jcp.ic_block = jcp.dimK_block; + jcp.nb_ic = jcp.dimK_nb_block; + jcp.oc_simd_block = jcp.dimM_simd_block; + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + jcp.tile_4fma_padding = 0; // only relevant for backward weights + + return res; +} + +status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); + + if (st != status::success) + return st; + + jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); + jcp.oc_simd_block = jcp.dimK_reg_block; + jcp.oc_block = jcp.dimK_block; + jcp.nb_oc = jcp.dimK_nb_block; + jcp.ic_simd_block = jcp.dimM_simd_block; + jcp.ic_block = jcp.dimM_block; + jcp.nb_ic = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + jcp.tile_4fma_padding = 0; // only relevant for backward weights + + return res; +} + +void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate() +{ + auto load_B = [=](int reg_idx, int offset) { + for (int i = 0; i < 4; i++) { + vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]); + } + }; + + preamble(); + int curr = 0; + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + int origB_offset = (j * alpha + i) * jcp.dimK_4fma; + size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block * + jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block * + jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float); + mov(reg_transB_idx, transB_offset); + for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) { + /*double buffering to hide load latencies*/ + int next = (curr + 4) % 8; + if (i == 0 && tb == 0) { + load_B(0, origB_offset); + } + if (tb + 4 < (jcp.dimK_4fma -1)) { + load_B(next, origB_offset + 4); + } else if (i < alpha - 1) { + load_B(next, origB_offset + jcp.dimK_4fma); + } + + vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1)); + vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3)); + vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1)); + vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3)); + + vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9)); + vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9)); + + vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1)); + vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1)); + + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * tb * jcp.dimN_reg_block], + Zmm(curr+2)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 1) * jcp.dimN_reg_block], + Zmm(curr+3)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 2) * jcp.dimN_reg_block], + Zmm(8)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 3) * jcp.dimN_reg_block], + Zmm(9)); + curr = next; + + } + } + } + postamble(); + ret(); +} +void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate( + bool is_first_tile) +{ + // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) + // for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++) + // for (int nb_tile_block_ur = 0; nb_tile_block_ur < + // jcp.nb_tile_block_ur; nb_tile_block_ur++) + // for (int tile_block_ur = 0; tile_block_ur < + // jcp.tile_block_ur; tile_block_ur++) + // for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3) + // U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] += + // M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block] + // * + // broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur]) + auto inner_loops = [=]() { + int inc_fma = jcp.ver == ver_4fma ? 4 : 1; + const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; + prefetcher_t L1_pf(this, reg_srcB, L1, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma + / inc_fma, + fma_ipc); + prefetcher_t L2_pf(this, reg_srcB, L2, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma + / inc_fma, + fma_ipc); + + auto load_A = [=](int reg_idx, int offset) { + for (int i = 0; i < inc_fma; i++) { + vmovups(Zmm(reg_idx + i), + zword[reg_srcA + + sizeof(float) * jcp.dimM_simd_block * (offset + i)]); + } + }; + + Label dimM_block_loop, dimK_block_loop, dimN_block_loop; + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + { /************* OC_block (M) loop ***********/ + if (jcp.dimN_block > 1) { + mov(reg_dimN_block_loop_cnt, jcp.dimN_block); + L(dimN_block_loop); + } + { /*************** IC_block (N) loop *********/ + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { + Zmm zmm(jcp.zmm_start + dimN_reg_block); + if (is_first_tile) + vpxord(zmm, zmm, zmm); + else + vmovups(zmm, zword[reg_dstC + + dimN_reg_block * jcp.dimM_simd_block * + sizeof(float)]); + } + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + { /************* nb_tile_ur(K) loop ********/ + int next = 0; + if (jcp.double_buffering) { + load_A(next, 0); + } + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block++) { + int srcB_offset = dimK_reg_block * jcp.dimK_4fma + * jcp.dimN_reg_block; + for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma; + dimK_4fma += inc_fma) { + int current = next; + if (jcp.double_buffering) { + next = (dimK_reg_block * jcp.dimK_4fma + + dimK_4fma + inc_fma) + % (2 * inc_fma); + load_A(next, dimK_reg_block * jcp.dimK_4fma + + dimK_4fma + inc_fma); + } else { + next = 0; + load_A(next, dimK_reg_block * jcp.dimK_4fma + + dimK_4fma); + } + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; + ++dimN_reg_block) { + L1_pf.prefetch(srcB_offset / inc_fma + + dimK_4fma / inc_fma + * jcp.dimN_reg_block + + dimN_reg_block); + L2_pf.prefetch(srcB_offset / inc_fma + + dimK_4fma / inc_fma + * jcp.dimN_reg_block + + dimN_reg_block); + if (jcp.ver == ver_4fma) { + int srcB_trans_offset = (dimK_4fma / 4) * 64 + + dimK_4fma % 4; + v4fmaddps( + Zmm(jcp.zmm_start + dimN_reg_block), + Zmm(current), + EVEX_compress_addr(reg_srcB, + sizeof(float) * ( + srcB_offset + + srcB_trans_offset + + (dimN_reg_block % 4) * 16 + + (dimN_reg_block / 4) * 4))); + } else { + vfmadd231ps( + Zmm(jcp.zmm_start + dimN_reg_block), + Zmm(current), + EVEX_compress_addr(reg_srcB, + sizeof(float) * (srcB_offset + dimN_reg_block), + true)); + } + } + } + } + } + + add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma + * jcp.dimM_simd_block * sizeof(float)); + add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block + * jcp.dimK_4fma * sizeof(float)); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + /******** Write C back to memory *******/ + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { + Zmm zmm(jcp.zmm_start + dimN_reg_block); + vmovups(zword[reg_dstC + + dimN_reg_block * jcp.dimM_simd_block * sizeof(float)], + zmm); + } + + sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block * + jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float)); + add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block + * sizeof(float)); + if (jcp.dimN_block > 1) { + sub(reg_dimN_block_loop_cnt, 1); + jnz(dimN_block_loop); + } + } + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimN_block * jcp.dimK_block + * jcp.dimK_reg_block * jcp.dimN_reg_block + * jcp.dimK_4fma * sizeof(float)); + add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float)); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + } + }; + + /* Preamble */ + // register used to handle long fma encoding + preamble(); + mov(reg_srcA, reg_srcA_const); + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +namespace { +bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) +{ + float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw; + lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw; + lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; + lhs *= sizeof(float); + float rhs = C * L1_cache_size; + return (lhs <= rhs); +} + +bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) +{ + float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma + * dimM_simdw; + lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; + lhs *= sizeof(float); + float rhs = C * L1_cache_size; + return (lhs <= rhs); +} + +bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, + float C) +{ + float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block + * dimK_4fma; + lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block + * dimN_reg_block; + lhs *= sizeof(float); + float rhs = C * L2_cache_size; + return (lhs <= rhs); +} + +bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, + float C) +{ + float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block; + lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma; + lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block + * dimN_reg_block; + lhs *= sizeof(float); + float rhs = C * L2_cache_size; + return (lhs <= rhs); +} +} // namespace + +status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp) +{ + /*************** Choose dimN_reg_block (ic_simd_block) + * *******************************/ + jcp.dimN = jcp.ic; + /*Hardcoded to 16 because N = ic for bwd weights and + innermost dimension for ic is assumed 16 in src transforms. This + choice covers load latencies while maintaining simplicity of kernel + for POR topologies. FIXME in future??: Will not work for future topologies + when ic%16 != 0*/ + jcp.dimN_reg_block = jcp.ic_simd_block; + + /****************************** Choose dimK_block + * **************************/ + // No freedom for choosing dimM_simd_block because ic_simd_block + // is determined by input data format + jcp.dimM_simd_block = jcp.oc_simd_block; + + auto test_cond1bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) + && (dimK_block > current_best); + }; + + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) + && (dimK_block > current_best); + }; + + auto test_cond2bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f) + && (dimK_block > current_best); + }; + + auto test_cond2_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block); + if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block); + + jcp.dimK_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block); + if (jcp.dimK_reg_block < jcp.dimK_block) { + jcp.dimK_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimK_block, 1, test_cond1_dimK_block); + } + jcp.dimK_block /= jcp.dimK_reg_block; + jcp.dimK_nb_block + = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block; + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; + + /***************************** Chose dimN block + * ****************************/ + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block, + jcp.dimN_reg_block, 0.5f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.ic_block = jcp.dimN_block; + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + /********************************* Choose dimM block + * ************************/ + jcp.dimM = jcp.oc; + + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1, + jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block, + 1.0f) + && (dimM_block > current_best) + && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2; + }; + + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); + jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; + + jcp.sched_policy = WSCHED_WEI_S_D_G_W; + return status::success; +} + +status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d) +{ + jcp.nthr = mkldnn_get_max_threads(); + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + jcp.kh = diff_weights_d.dims()[with_groups + 2]; + jcp.kw = diff_weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + if (mayiuse(avx512_core)) + return status::unimplemented; + if (!mayiuse(avx512_common)) + return status::unimplemented; + else if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + else + jcp.ver = ver_fma; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + // Winograd kernel works only for 3x3 convolution with stride 1 + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + /*************************** New Kernel Parameters + * *****************************/ + jcp.ic_simd_block = simd_w; + jcp.oc_simd_block = simd_w; + jcp.dimK_4fma = 1; + jcp.tile_4fma_padding = 0; + +#define MAX_4FMA_UR 8 + if (jcp.ver == ver_4fma) { + auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma, + int current_best) { + return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR) + && (dimK_4fma > current_best); + }; + jcp.dimK_4fma = get_divisor_satisfying_cond( + jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma); + if (jcp.dimK_4fma == 1) + jcp.dimK_4fma = 4; + if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0) + jcp.tile_4fma_padding = jcp.dimK_4fma + - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma); + } + + jcp.tile_4fma = jcp.dimK_4fma; + /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src + * transform + * will not work correctly, this is solved by applying padding.*/ + jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + jcp.dimN = jcp.ic; + jcp.dimM = jcp.oc; + + jcp.double_buffering = true; + if (jcp.double_buffering) + jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2; + else + jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1; + jcp.nb_reg = 32 - jcp.zmm_start; + + jcp.sched_policy = WSCHED_INVALID; + status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp); + assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W); + + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; + + jcp.ic_block = jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + + return res; + +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp new file mode 100644 index 0000000000..6c117143f5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp @@ -0,0 +1,179 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP +#define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +//alpha determines the output tile_size +constexpr int alpha = 6; +constexpr int tile_size = 4; +//simd length used for vectorization +constexpr int simd_w = 16; + +struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator { + _jit_avx512_common_conv_winograd_data_kernel_f32( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + //******************* First iter kernel ********************// + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter + = (decltype(gemm_loop_ker_first_iter)) this->getCode(); + + //************** Subsequent iterations kernel **************// + if (jcp.dimK_nb_block > 1) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32) + + static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d); + + static status_t init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + +protected: + using reg64_t = const Xbyak::Reg64; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(bool is_beta_zero); + + /* registers used for GEMM */ + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r11; +}; + +struct jit_avx512_common_conv_winograd_fwd_kernel_f32 + : _jit_avx512_common_conv_winograd_data_kernel_f32 { + using _jit_avx512_common_conv_winograd_data_kernel_f32:: + _jit_avx512_common_conv_winograd_data_kernel_f32; + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); +}; + +struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32 + : public _jit_avx512_common_conv_winograd_data_kernel_f32 { + using _jit_avx512_common_conv_winograd_data_kernel_f32:: + _jit_avx512_common_conv_winograd_data_kernel_f32; + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); +}; + +struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 + : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32) + + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + + //******************* First iter kernel ********************// + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr; + } + + if (jcp.tile_block > 1) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + + if (jcp.ver == ver_4fma) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->transpose_ker_generate(); + transpose_4fma_ker = (decltype(transpose_4fma_ker))addr; + } + } + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + void (*transpose_4fma_ker)(float *, float *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(bool is_first_tile); + void transpose_ker_generate(); + + reg64_t reg_origB = abi_param2; + reg64_t reg_transB = abi_param1; + + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA_const = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_sp = rsp; + reg64_t reg_srcA = r9; + reg64_t reg_nb_ic = r10; + reg64_t reg_loop_cpt = r11; + reg64_t reg_transB_idx = r13; + + /* Registers used by new kernel */ + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r12; + reg64_t reg_dimN_block_loop_cnt = r11; +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp new file mode 100644 index 0000000000..abddc19221 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp @@ -0,0 +1,1526 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace nstl; + +using jit_conv_ker_t = void (*)(jit_conv_call_s *); + +#define PIPELINE(field) \ + do { \ + p.field = p.field ## _prf; \ + p.field ## _prf = field; \ + } while (0) + +inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + + if (p.src) + ker(&p); +} +// The special case for the driver with ow-parallelization (FWD) +// TODO: implement it for BWD_D and BWD_W too +inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int owb) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(owb); + + if (p.src) + ker(&p); +} + +inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int kd_padding) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(kd_padding); + + if (p.src) + ker(&p); +} +// The special case for the driver with ow-parallelization (FWD) +// TODO: implement it for BWD_D and BWD_W too +inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker, + jit_conv_call_s &p, const void *src, const void *dst, const void *filt, + const void *bias, int channel, int kh_padding, int kd_padding, int owb) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(kd_padding); + PIPELINE(owb); + + if (p.src) + ker(&p); +} + +void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int d_index, int d_worksize, + int kd_padding /* kd_work_size */, size_t kd_offset) { + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kd_padding); + PIPELINE(d_worksize); + PIPELINE(d_index); + PIPELINE(kd_offset); + + if (p.src) + ker(&p); +} +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template +void jit_avx512_common_convolution_fwd_t::prepare_padded_bias(const dst_data_t *&bias, + const memory_tracking::grantor_t &scratchpad) const { + if (!pd()->wants_padded_bias()) return; + + auto padded_bias = scratchpad.template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding); + utils::array_set(padded_bias + pd()->jcp_.oc_without_padding, + (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding); + bias = padded_bias; +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow; + + int nthr; + if (jcp.aligned_threads) + nthr = jcp.aligned_threads; + else + nthr = mkldnn_get_max_threads(); + + parallel(nthr, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) { + int dummy{0}; + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gncw) { + int dummy{0}; + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ, + oc_chunks, owb, jcp.nb_ow, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + auto bias_w = bias ? bias + g_oc : nullptr; + auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s); + auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src_w, dst_w, wht_w, bias_w, icb, 1, owb); + + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + if (jcp.loop_order == loop_cwgn) { + int dummy{0}; + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gncw) { + int dummy{0}; + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, + occ, oc_chunks, owb, jcp.nb_ow, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + } + } + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow; + + int nthr; + if (jcp.aligned_threads) + nthr = jcp.aligned_threads; + else + nthr = mkldnn_get_max_threads(); + + parallel(nthr, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, oh_s{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, + occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int work_rem = end - start; + + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + auto bias_w = bias ? bias + g_oc : nullptr; + + for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) { + int ih_b = -jcp.t_pad + oh_b * jcp.stride_h; + + auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s); + auto src_w + = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s); + auto wht_w + = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); + ++icb) { + auto src_c = src_w; + auto dst_c = dst_w; + for (int oj = oh_b, ij = ih_b; + oj < min(oh_e, oh_b + jcp.h_blocking); + ++oj, ij += jcp.stride_h) { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = div_up(max(0, -ij), dilate_h); + int i_b_overflow = div_up(max(0, ij - jcp.ih + + (jcp.kh - 1) * dilate_h + 1), dilate_h); + int kh_padding = nstl::max( + 0, jcp.kh - i_t_overflow - i_b_overflow); + + auto aux_src = src_c + + i_t_overflow * dilate_h * src_h_stride; + auto aux_wht = wht_w + i_t_overflow * wht_h_stride; + + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, + par_conv, aux_src, dst_c, aux_wht, bias_w, icb, + kh_padding, owb); + + src_c += src_h_stride * jcp.stride_h; + dst_c += dst_h_stride; + } + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + } + + if (jcp.loop_order == loop_cwgn) + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, occ, + oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template +void jit_avx512_common_convolution_fwd_t:: +execute_forward_3d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + parallel(0, [&](const int ithr, const int nthr) { + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int start{0}, end{0}, start_copy; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.od * jcp.oh + * jcp.nb_ow; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_d_stride = src_d.blk_off(0, 0, 1); + size_t src_h_stride = src_d.blk_off(0, 0, 0, 1); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1); + size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) + nd_iterator_init(start, + occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, + od_s, jcp.od, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, + od_s, jcp.od, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int work_rem = end - start; + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + + int id_s = -jcp.f_pad + od_s * jcp.stride_d; + + int dilate_d = jcp.dilate_d + 1; + int d_t_overflow = div_up(max(0, -id_s), dilate_d); + int d_b_overflow = div_up( + max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1), + dilate_d); + int kd_padding = nstl::max(0, + jcp.kd - d_t_overflow - d_b_overflow); + + auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0; + auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s, + iw_s) + d_t_overflow * dilate_d * src_d_stride; + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2) + + d_t_overflow * wht_d_stride; + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { + auto src_c = src_w; + auto dst_c = dst_w; + for (int oj = oh_s, ij = ih_s; + oj < oh_e; ++oj, ij += jcp.stride_h) + { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = div_up(max(0, -ij), dilate_h); + int i_b_overflow = div_up( + max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + + 1), + dilate_h); + int kh_padding = nstl::max(0, + jcp.kh - i_t_overflow - i_b_overflow); + jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker, + par_conv, + src_c + i_t_overflow * dilate_h * src_h_stride, + dst_c, wht_w + i_t_overflow * wht_h_stride, + bias_w, icb, kh_padding, kd_padding, owb); + + src_c += src_h_stride * jcp.stride_h; + dst_c += dst_h_stride; + } + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + + if (jcp.loop_order == loop_cwgn) + nd_iterator_jump(start, end, + occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, + od_s, jcp.od, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, + od_s, jcp.od, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + } + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template struct jit_avx512_common_convolution_fwd_t; + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_1d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}; + if (jcp.loop_order == loop_cgn) { + int dummy{0}; + nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n, + jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gnc) { + int dummy{0}; + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc, + ic_chunks, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w, diff_dst_w, wht_w, 0, ocb, 1); + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) { + int dummy{0}; + nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups, + n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gnc) { + int dummy{0}; + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc, + ic_chunks, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + } + } + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1); + }); +} + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_2d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1; + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}, ih_s{0}; + if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + int work_rem = end - start; + int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem; + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + for (int ij = ih_s; ij < ih_e; ++ij) { + int oj, k_len, k_lo; + if (is_fast_path) { // dilate == 0 && stride == 1 + int i_t_overflow = max(0, jcp.kh - 1 - ij + - jcp.t_pad); + int i_b_overflow = max(0, jcp.kh - jcp.ih + ij + - jcp.b_pad); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow; + } else if (jcp.dilate_h != 0) { // stride == 1 + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int i_t_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + - ij - jcp.t_pad), dilate_h); + int i_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 + - jcp.ih + ij - jcp.b_pad), dilate_h); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow * dilate_h; + } else { // dilate == 0 + int i_t_overflow = max(0, (jcp.kh - 1 - ij + - jcp.t_pad) / jcp.stride_h); + int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij + - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ij) % jcp.stride_h); + int overflow_kh_lo = (ij + jcp.t_pad) + % jcp.stride_h; + + k_len = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow + - i_b_overflow; + k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; + oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; + } + assert(k_len >= 0); + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w + ij * diff_src_h_stride, + diff_dst_w + oj * diff_dst_h_stride, + wht_w + k_lo * wht_h_stride, + 0, ocb, k_len); + } + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1); + }); +} + +template +void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_3d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.id * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1); + size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1); + size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); + size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1; + bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1; + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0}; + if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, + ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, + ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + int work_rem = end - start; + int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem; + int d_len = 0, d_lo = 0, d_oj = 0; + if (is_fast_path_d) { // dilate == 0 && stride == 1 + int d_t_overflow = max(0, jcp.kd - 1 - id_s + - jcp.f_pad); + int d_b_overflow = max(0, jcp.kd - jcp.id + id_s + - jcp.back_pad); + d_len = jcp.kd - d_t_overflow - d_b_overflow; + d_lo = d_b_overflow; + d_oj = id_s + jcp.f_pad - d_b_overflow; + } else if (jcp.dilate_d != 0) { // stride == 1 + int dilate_d = jcp.dilate_d + 1; + // Note: use div_up to account for "holes" in filter + int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + - id_s - jcp.f_pad), dilate_d); + int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1 + - jcp.id + id_s - jcp.back_pad), dilate_d); + d_len = jcp.kd - d_t_overflow - d_b_overflow; + d_lo = d_b_overflow; + d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d; + } else { // dilate == 0 + int d_t_overflow = max(0, (jcp.kd - 1 - id_s + - jcp.f_pad) / jcp.stride_d); + int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s + - jcp.back_pad) / jcp.stride_d); + int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1 + + jcp.back_pad - id_s) % jcp.stride_d); + int overflow_kd_lo = (id_s + jcp.f_pad) + % jcp.stride_d; + + d_len = (overflow_kd_hi - overflow_kd_lo) + / jcp.stride_d + 1 - d_t_overflow + - d_b_overflow; + d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d; + d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d; + } + assert(d_len >= 0); + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb) + + id_s * diff_src_d_stride; + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2) + + d_oj * diff_dst_d_stride; + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb) + + d_lo * wht_d_stride; + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + for (int ij = ih_s; ij < ih_e; ++ij) { + int oj, k_len, k_lo; + if (is_fast_path_h) { // dilate == 0 && stride == 1 + int i_t_overflow = max(0, jcp.kh - 1 - ij + - jcp.t_pad); + int i_b_overflow = max(0, jcp.kh - jcp.ih + ij + - jcp.b_pad); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow; + } else if (jcp.dilate_h != 0) { // stride == 1 + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int i_t_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + - ij - jcp.t_pad), dilate_h); + int i_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 + - jcp.ih + ij - jcp.b_pad), dilate_h); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow * dilate_h; + } else { // dilate == 0 + int i_t_overflow = max(0, (jcp.kh - 1 - ij + - jcp.t_pad) / jcp.stride_h); + int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij + - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ij) % jcp.stride_h); + int overflow_kh_lo = (ij + jcp.t_pad) + % jcp.stride_h; + + k_len = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow + - i_b_overflow; + k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; + oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; + } + assert(k_len >= 0); + + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w + ij * diff_src_h_stride, + diff_dst_w + oj * diff_dst_h_stride, + wht_w + k_lo * wht_h_stride, + 0, ocb, k_len, d_len); + } + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, + ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, + ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1, 1); + }); +} + +template struct jit_avx512_common_convolution_bwd_data_t; + +template +jit_avx512_common_convolution_bwd_weights_t:: +jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) + , trans_kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) +{ + const auto &j = pd()->jcp_; + + nthr_ = j.nthr; + nthr_mb_ = j.nthr_mb; + nthr_g_ = j.nthr_g; + nthr_oc_b_ = j.nthr_oc_b; + nthr_ic_b_ = j.nthr_ic_b; + + kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j); + + if (j.ver == ver_4fma) + trans_kernel_ = create_trans_src(&j); + + if (nthr_mb_ > 1) + acc_ker_ = new cpu_accumulator_1d_t(); + + reducer_bias_ = + new cpu_reducer_t(pd()->reducer_bia_conf_); +} + +template +struct jit_avx512_common_convolution_bwd_weights_t::thread_info_t { + const src_data_t *src; + const diff_dst_data_t *diff_dst; + const diff_weights_data_t *diff_weights; + diff_weights_data_t *diff_bias; + + const memory_tracking::grantor_t scratchpad; + + src_data_t *tr_src; + simple_barrier::ctx_t *tr_src_bctx; + + diff_dst_data_t *tr_diff_dst; + simple_barrier::ctx_t *tr_diff_dst_bctx; + + diff_weights_data_t *wei_bia_reduction; + simple_barrier::ctx_t *wei_bia_reduction_bctx; + + int ithr; + int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb; + int ithr_but_oc; + int ithr_but_ic; + + int img_start = 0, img_end = 0, img_work; + int g_start = 0, g_end = 0, g_work; + int oc_b_start = 0, oc_b_end = 0, oc_b_work; + int ic_b_start = 0, ic_b_end = 0, ic_b_work; + + thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self, + const exec_ctx_t &ctx, int ithr) + : scratchpad(self->scratchpad(ctx)), ithr(ithr) + { + diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + diff_weights = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + diff_bias = self->pd()->wants_padded_bias() + ? scratchpad.template get( + key_conv_padded_bias) + : CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); + + tr_src = scratchpad.template get(key_conv_tr_src); + tr_src_bctx = scratchpad.template get( + key_conv_tr_src_bctx); + + tr_diff_dst = scratchpad.template get( + key_conv_tr_diff_dst); + tr_diff_dst_bctx = scratchpad.template get( + key_conv_tr_diff_dst_bctx); + + wei_bia_reduction = scratchpad.template get( + key_conv_wei_bia_reduction); + wei_bia_reduction_bctx = scratchpad.template get( + key_conv_wei_bia_reduction_bctx); + + ithr_ic_b = ithr % self->nthr_ic_b_; + ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_; + ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_; + ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_; + + ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_ + + ithr_ic_b; + + ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_ + + ithr_oc_b; + + const auto &jcp = self->kernel_->jcp; + + /* reduction dimension */ + balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end); + img_work = img_end - img_start; + + /* independent dimensions */ + balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end); + g_work = g_end - g_start; + + balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, + oc_b_end); + oc_b_work = oc_b_end - oc_b_start; + + balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, + ic_b_end); + ic_b_work = ic_b_end - ic_b_start; + } +}; + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_weights(const thread_info_t *ti) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd; + + diff_weights_data_t *diff_wei = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_weights + : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; + diff_weights_data_t *diff_bia = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_bias + : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size + + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; + + // TODO: use memory descriptor with the same fmt as src (or use a macro :)) + auto tr_src_off = [&](int ithr_mb, int ic, int ij) { + const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; + const size_t tr_chn_size = tr_row_size * jcp.ih; + const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups; + + return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size; + }; + + auto uker_trans = [&](int img) { + const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih; + + int start{0}, end{0}; + balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end); + const int my_work = end - start; + + int g{0}, ic_b{0}, j{0}; + nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih); + g += ti->g_start; + ic_b += ti->ic_b_start; + + const int _ic = g * jcp.nb_ic + ic_b; + src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)]; + src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)]; + + assert(jcp.ic_block == 16); + const int src_stride = jcp.iw * jcp.ic_block; + const int tr_src_stride = jcp.tr_iw * jcp.ic_block; + + const int pf_depth = 2; + struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth]; + + for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) { + pf_circ_buf[iwork % pf_depth] = {src1, tr_src1}; + + if (iwork >= pf_depth - 1) { + int old_idx = (iwork - pf_depth + 1) % pf_depth; + auto ctx = jit_trans_src_t::ctx_t(); + ctx.src = pf_circ_buf[old_idx].src; + ctx.tr_src = pf_circ_buf[old_idx].tr_src; + ctx.src_prf = src1; + ctx.tr_src_prf = tr_src1; + (*trans_kernel_)(&ctx); + } + src1 += src_stride; + tr_src1 += tr_src_stride; + } +#if 0 + // reference transposition + const int l_pad = jcp.l_pad; + const int iwlp = l_pad + jcp.iw; + const int tr_iw = jcp.tr_iw; + + for (size_t iwork = start; iwork < end; iwork++) { + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = 0; i < l_pad; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; + + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = l_pad; i < iwlp; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] + = (src_data_t)src1[(i - l_pad) * 16 + j]; + + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = iwlp; i < tr_iw; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; + + src1 += src_stride; + tr_src1 += tr_src_stride; + } +#endif + }; + + if (jcp.is_1stconv && jcp.ver == ver_4fma) { + /* prepare contexts */ + auto tr_ctx = jit_trans_src_t::ctx_t(); + tr_ctx.tr_src = ti->tr_src + + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld; + + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1)); + tr_ctx.nthr_oc_b = nthr_oc_b_; + int ih_start{0}, ih_end{0}; + balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end); + tr_ctx.tr_src_ih_start = ih_start; + tr_ctx.tr_src_ih_end = ih_end; + tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc; + + auto p = jit_conv_call_s(); + p.src = tr_ctx.tr_src; + + /* zero diff_bias if applicable */ + if (jcp.with_bias && ti->ithr_ic_b == 0) { + assert(jcp.oc_block == 16); + for (int oc_b = ti->ic_b_start; oc_b < ti->oc_b_end; ++oc_b) { + diff_weights_data_t *db = &diff_bia[oc_b * 16]; + for (int o = 0; o < 16; ++o) + db[o] = 0; + } + } + + for (int img = ti->img_start; img < ti->img_end; ++img) { + p.flags = (img == ti->img_start) * FLAG_MB_FIRST; + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _ic = g * jcp.nb_ic + ic_b; + tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)]; + + (*trans_kernel_)(&tr_ctx); + + if (ic_b == 0) + p.flags |= FLAG_IC_FIRST; + else + p.flags &= ~FLAG_IC_FIRST; + + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + const int _oc = g * jcp.nb_oc + oc_b; + p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; + + const size_t off = + wht_blk_off(diff_weights_d, g, oc_b, ic_b); + p.filt = diff_wei + off; + p.bias = diff_bia + _oc * jcp.oc_block; + + kernel_->jit_ker(&p); + } + } + } + } + } else { + for (int img = ti->img_start; img < ti->img_end; ++img) { + auto p = jit_conv_call_s(); + + if (jcp.ver == ver_4fma) { + /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */ + using simple_barrier::barrier; + if (nthr_oc_b_ > 1) + barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); + uker_trans(img); + if (nthr_oc_b_ > 1) + barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); + } + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _oc = g * jcp.nb_oc + oc_b; + const int _ic = g * jcp.nb_ic + ic_b; + + jit_conv_ker_pipeline(kernel_->jit_ker, p, + jcp.ver == ver_4fma + ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] + : &ti->src[src_d.blk_off(img, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img, _oc)], + diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), + 0, (img == ti->img_start), 0); + + } + } + } + + const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; + const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; + jit_conv_ker_pipeline(kernel_->jit_ker, p, + jcp.ver == ver_4fma + ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] + : &ti->src[src_d.blk_off(img + 1, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], + diff_wei + wht_blk_off( + diff_weights_d, ti->g_start, + ti->oc_b_start, ti->ic_b_start), + 0, 0, 0); + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_weights_3d(const thread_info_t *ti) const +{ + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size + = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd; + + diff_weights_data_t *diff_wei = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_weights + : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; + diff_weights_data_t *diff_bia = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_bias + : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size + + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; + + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + const int input_step = jcp.ih * jcp.iw * inp_mult; + const int output_step = jcp.ow * jcp.oh * jcp.oc_block; + int img{0}, od_s{0}; + int img_start = ti->img_start, img_end = ti->img_end; + nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); + const int img_first = img; + + while (img_start < img_end) { + auto p = jit_conv_call_s(); + + int work_rem = img_end - img_start; + const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem; + const int id_s = od_s * jcp.stride_d; + const int ik_overlap = nstl::max(0, id_s - jcp.f_pad); + const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s); + const int kd_back_pad + = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd); + int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw + * jcp.ic_block * jcp.oc_block * jcp.typesize_out; + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _oc = g * jcp.nb_oc + oc_b; + const int _ic = g * jcp.nb_ic + ic_b; + + auto src = &ti->src[src_d.blk_off(img, _ic) + + ik_overlap * input_step]; + auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc) + + od_s * output_step]; + + jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst, + diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), + diff_bia + _oc * 16, (img == img_first), od_s, od_e, + jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off); + + if (ic_b == 0) p.flags = 0; + else p.flags = 1; + } + } + } + + const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; + const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; + jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, + &ti->src[src_d.blk_off(img + 1, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], + diff_wei + wht_blk_off(diff_weights_d, ti->g_start, + ti->oc_b_start, ti->ic_b_start), + diff_bia, 0, 0, 0, 0, 0); + nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::reduce_diff_weights(const thread_info_t *ti) const { + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + const int bia_size = jcp.ngroups * jcp.oc; + const diff_weights_data_t *diff_bias_ws + = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size; + + /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ + simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); + + const int ic_b_kh_work = ti->ic_b_work * jcp.kh; + const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; + + int start{0}, end{0}; + balance211(work, nthr_mb_, ti->ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + int w = start; + int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; + nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + while (w < end) { + const int g = ti->g_start + sub_g_start; + const int oc_b = ti->oc_b_start + sub_oc_b_start; + const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh; + const int kh = sub_ic_b_kh_start % jcp.kh; + + const int acc_size + = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) + * jcp.kw * jcp.ic_block * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh); + + diff_weights_data_t *d + = (diff_weights_data_t *)ti->diff_weights + off; + diff_weights_data_t *s + = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + } + + if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) { + if (ti->ithr == 0) + acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias, + diff_bias_ws, bia_size); + diff_bias_ws += bia_size; + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::reduce_diff_weights_3d(const thread_info_t *ti) const { + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw + * jcp.kd; + + /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ + simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); + + const int ic_b_kh_work = ti->ic_b_work * jcp.kd; + const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; + + int start{0}, end{0}; + balance211(work, nthr_mb_, ti->ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + int w = start; + int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; + nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + while (w < end) { + const int g = ti->g_start + sub_g_start; + const int oc_b = ti->oc_b_start + sub_oc_b_start; + const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd; + const int kd = sub_ic_b_kh_start % jcp.kd; + + const int acc_size + = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) + * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd); + diff_weights_data_t *d + = (diff_weights_data_t *)ti->diff_weights + off; + diff_weights_data_t *s + = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_bias(const thread_info_t *ti) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + auto rb = this->reducer_bias_; + assert(nthr_ == rb->balancer().nthr_); + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t( + ti->scratchpad, prefix_reducer_bia); + + const auto &jcp = kernel_->jcp; + + if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return; + + const int b_job_start = rb->balancer().ithr_job_off(ti->ithr); + const int b_njobs = rb->balancer().ithr_njobs(ti->ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ti->ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc); + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + + const diff_dst_data_t *d_dst + = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; + diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr, + ti->diff_bias, reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 16; ++o) + d_bias[o] = 0; + for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 16; ++o) + d_bias[o] += d_dst[o]; + d_dst += 16; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); + } + } + + rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad); +} + +template +void jit_avx512_common_convolution_bwd_weights_t::compute_diff_bias_3d(const thread_info_t *ti) const { + + const auto &jcp = kernel_->jcp; + + const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh + * jcp.kw * jcp.kd; + const int bia_size = jcp.ngroups * jcp.oc; + const diff_weights_data_t *diff_bias_ws + = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size; + + if (nthr_mb_ > 1) mkldnn_thr_barrier(); + + if (ti->ithr == 0) + { + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size); + diff_bias_ws += bia_size; + } + } +} + +template +void jit_avx512_common_convolution_bwd_weights_t::prepare_scratchpad_data(const exec_ctx_t &ctx) const +{ + const auto &j = pd()->jcp_; + auto scratchpad = this->scratchpad(ctx); + + if (j.ver == ver_4fma) { + if (!j.is_1stconv) { + // XXX: See the comment about tr_iw and guarding elements in + // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() + const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic; + const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw; + + auto tr_src = scratchpad.template get(key_conv_tr_src); + /* to avoid NaNs in computations we zero tail num_guard_elems for + * each possible thread group */ + + for (int ithr = 1; ithr <= max_nthr; ++ithr) { + src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr]; + for (int i = 0; i < j.tr_src_num_guard_elems; ++i) + ts[i] = 0; + } + } + + if (j.nthr_oc_b > 1) { + const int tr_src_bctx_size = j.nthr / j.nthr_oc_b; + auto tr_src_bctx = scratchpad.template get( + key_conv_tr_src_bctx); + for (int i = 0; i < tr_src_bctx_size; ++i) + simple_barrier::ctx_init(&tr_src_bctx[i]); + } + } + + if (nthr_mb_ > 1) { + simple_barrier::ctx_init(scratchpad.template get( + key_conv_wei_bia_reduction_bctx)); + } + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); +} + +template +void jit_avx512_common_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) const { + prepare_scratchpad_data(ctx); + + parallel(nthr_, [&](const int ithr, const int nthr) { + assert(nthr_ == nthr); + + thread_info_t thread_info(this, ctx, ithr); + + if (utils::one_of(pd()->ndims(), 3, 4)) { + compute_diff_weights(&thread_info); + if (nthr_mb_ > 1) reduce_diff_weights(&thread_info); + if (pd()->with_bias()) compute_diff_bias(&thread_info); + } else if (pd()->ndims() == 5) { + compute_diff_weights_3d(&thread_info); + if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info); + if (pd()->with_bias()) compute_diff_bias_3d(&thread_info); + } else { + assert(false); + } + }); + + /* TODO: put that into compute_diff_bias() */ + if (pd()->wants_padded_bias()) { + auto diff_bias = scratchpad(ctx).template get( + key_conv_padded_bias); + auto diff_bias_in = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); + for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +template struct jit_avx512_common_convolution_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp new file mode 100644 index 0000000000..3341c3ebe0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP +#define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_transpose_src_utils.hpp" +#include "jit_avx512_common_conv_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_fwd_kernel::init_conf( + jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, + *attr(), mkldnn_get_max_threads()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad, + jcp_); + + return status; + } + + jit_conv_conf_t jcp_; + }; + + jit_avx512_common_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + ~jit_avx512_common_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->ndims() == 3) + execute_forward_1d(ctx); + else if (pd()->ndims() == 4) + execute_forward_2d(ctx); + else if (pd()->ndims() == 5) + execute_forward_3d(ctx); + else + assert(false); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); + + return status::success; + } + +private: + void prepare_padded_bias(const dst_data_t *&bias, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + void execute_forward_3d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_conv_fwd_kernel *kernel_; +}; + +template +struct jit_avx512_common_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, + data_type::undef, diff_dst_type, data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_, + *desc(), *diff_src_md(), *weights_md(), *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16o16i, gOIw16o16i, OIhw16o16i, gOIhw16o16i, + OIdhw16o16i, gOIdhw16o16i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_avx512_common_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; }; + + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_src_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->ndims() == 3) + execute_backward_data_1d(ctx); + else if (pd()->ndims() == 4) + execute_backward_data_2d(ctx); + else if (pd()->ndims() == 5) + execute_backward_data_3d(ctx); + else + assert(false); + return status::success; + } + +private: + void execute_backward_data_1d(const exec_ctx_t &ctx) const; + void execute_backward_data_2d(const exec_ctx_t &ctx) const; + void execute_backward_data_3d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_; +}; + +template +struct jit_avx512_common_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, diff_weights_type, + diff_weights_type, diff_dst_type, data_type::undef) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), src_md_, diff_weights_md_, + diff_bias_md_, diff_dst_md_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + return status; + } + + jit_conv_conf_t jcp_; + typename cpu_reducer_t::conf_t reducer_bia_conf_; + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, + max_buffer_size)); + } + } + }; + + jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd); + ~jit_avx512_common_convolution_bwd_weights_t() { + delete kernel_; + if (trans_kernel_) + delete trans_kernel_; + if (acc_ker_) + delete acc_ker_; + delete reducer_bias_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type diff_weights_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + void prepare_scratchpad_data(const exec_ctx_t &ctx) const; + struct thread_info_t; + void compute_diff_weights(const thread_info_t *) const; + void compute_diff_weights_3d(const thread_info_t *) const; + void reduce_diff_weights(const thread_info_t *) const; + void reduce_diff_weights_3d(const thread_info_t *) const; + void compute_diff_bias(const thread_info_t *) const; + void compute_diff_bias_3d(const thread_info_t *) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_; + + jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_; + jit_trans_src_t *trans_kernel_; + cpu_accumulator_1d_t *acc_ker_; + cpu_reducer_t *reducer_bias_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp new file mode 100644 index 0000000000..62247c0264 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp @@ -0,0 +1,1215 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifdef __INTEL_COMPILER +#include +#endif + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_convolution_winograd.hpp" + +#ifndef _MSC_VER +#define pragma_unroll _Pragma("unroll") +#else +#define pragma_unroll +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +namespace { + +unsigned int LLC_cache_size = get_cache_size(3, false); + +void inline load_ps(float *dest, const float *src_mem) { +#ifdef __INTEL_COMPILER + __m512 *Iv512 = (__m512 *)dest; + Iv512[0] = _mm512_load_ps(src_mem); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v]; +#endif +} + +void inline store_output(float *dest, const float *data, bool streamout) { +#ifdef __INTEL_COMPILER + if (streamout) + _mm512_stream_ps(dest, *((__m512 *)data)); + else + _mm512_store_ps(dest, *((__m512 *)data)); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + dest[v] = data[v]; +#endif +} + +void inline accum_output( + float *dest, float *data, bool streamout, bool with_relu_postsum) { +#ifdef __INTEL_COMPILER + __m512 _data = _mm512_loadu_ps(data); + __m512 _dest = _mm512_loadu_ps(dest); + _data = _mm512_add_ps(_data, _dest); + if (with_relu_postsum) + _data = _mm512_max_ps(_data, _mm512_setzero_ps()); + if (streamout) + _mm512_stream_ps(dest, _data); + else + _mm512_store_ps(dest, _data); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + data[v] += dest[v]; + + if (with_relu_postsum) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + if (data[v] < 0.f) + data[v] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + dest[v] = data[v]; +#endif +} +} + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) { + float Fw[6][16]; + float T[6][3][16]; + float t0[16]; + float t1[16]; + float t2[16]; + + for (int j = 0; j < 16; j++) { +#pragma unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < 16; k++) { + t0[k] = 0.26890756302521f * F[2][i][j][k]; + t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k]; + t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k]; + + T[0][i][k] = 1.13777777777778f * F[0][i][j][k]; + T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k]; + T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k]; + T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k]; + T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k]; + T[5][i][k] = F[2][i][j][k]; + } + } +#pragma unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < 16; k++) { + t0[k] = 0.26890756302521f * T[i][2][k]; + t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k]; + t2[k] = t0[k] + 0.119514472455649f * T[i][0][k]; + + Fw[0][k] = 1.13777777777778f * T[i][0][k]; + Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k]; + Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k]; + Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k]; + Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k]; + Fw[5][k] = T[i][2][k]; +#pragma unroll + for (int l = 0; l < 6; l++) { + Fw_[i][l][j][k] = Fw[l][k]; + } + } + } + } +} + +void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) { + float T[4][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + +#pragma unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = Mw[1][i][v] + Mw[2][i][v]; + t1[v] = Mw[3][i][v] + Mw[4][i][v]; + t2[v] = Mw[1][i][v] - Mw[2][i][v]; + t3[v] = Mw[3][i][v] - Mw[4][i][v]; + + T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v]; + T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f; + T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f; + T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v]; + } + } +#pragma unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][1][v] + T[i][2][v]; + t1[v] = T[i][3][v] + T[i][4][v]; + t2[v] = T[i][1][v] - T[i][2][v]; + t3[v] = T[i][3][v] - T[i][4][v]; + + O[i][0][v] = t0[v] + t1[v] + T[i][0][v]; + O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f; + O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f; + O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v]; + } + } +} + + +void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]) +{ + const float rcp3 = 1.0f / 3.0f; + const float rcp4 = 1.0f / 4.0f; + const float rcp6 = 1.0f / 6.0f; + const float rcp12 = 1.0f / 12.0f; + const float rcp24 = 1.0f / 24.0f; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + float T[6][4][16]; + +pragma_unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < 16; j++) { + t0[j] = F[2][i][j] * rcp6; + t1[j] = F[0][i][j] * -rcp6 - t0[j]; + t2[j] = F[0][i][j] * rcp24 + t0[j]; + t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6; + t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3; + + T[0][i][j] = F[0][i][j] * rcp4; + T[1][i][j] = t1[j] - t3[j]; + T[2][i][j] = t1[j] + t3[j]; + T[3][i][j] = t2[j] + t4[j]; + T[4][i][j] = t2[j] - t4[j]; + T[5][i][j] = F[3][i][j]; + } + } +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < 16; j++) { + t0[j] = T[i][2][j] * rcp6; + t1[j] = T[i][0][j] * -rcp6 - t0[j]; + t2[j] = T[i][0][j] * rcp24 + t0[j]; + t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6; + t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3; + + Fw[i][0][j] = T[i][0][j] * rcp4; + Fw[i][1][j] = t1[j] - t3[j]; + Fw[i][2][j] = t1[j] + t3[j]; + Fw[i][3][j] = t2[j] + t4[j]; + Fw[i][4][j] = t2[j] - t4[j]; + Fw[i][5][j] = T[i][3][j]; + } + } +} + +void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]) +{ + float T[4][6][16]; + float M_[3][16]; + float t0[16]; + float t1[16]; + float t2[16]; + + for (int j = 0; j < 16; j++) { +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int l = 0; l < 16; l++) { + t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l]; + t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l]; + t2[l] = t1[l] * 4.0f + Mw[5][i][j][l]; + + T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l]; + T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) + + 2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]); + T[2][i][l] = t0[l] + t2[l]; + } + } +pragma_unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int l = 0; l < 16; l++) { + t0[l] = T[i][1][l] + T[i][2][l]; + t1[l] = T[i][3][l] + T[i][4][l]; + t2[l] = t1[l] * 4.0f + T[i][5][l]; + + M_[0][l] = T[i][0][l] + t0[l] + t1[l]; + M_[1][l] = (T[i][1][l] - T[i][2][l]) + + 2.0f * (T[i][3][l] - T[i][4][l]); + M_[2][l] = t0[l] + t2[l]; + + for (int k = 0; k < 3; k++) { + M[i][k][j][l] = M_[k][l]; + } + } + } + } +} + +void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]) +{ + float T[6][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + float t5[16]; + +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = I[2][i][v] * -2.25f + I[4][i][v]; + t1[v] = I[1][i][v] * -2.25f + I[3][i][v]; + t2[v] = I[2][i][v] * -0.390625f + I[4][i][v]; + t3[v] = I[1][i][v] * -0.390625f + I[3][i][v]; + t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v]; + t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v]; + + T[0][i][v] = I[2][i][v] * -2.640625f + t4[v]; + T[1][i][v] = t1[v] * 0.625f + t0[v]; + T[2][i][v] = t1[v] * -0.625f + t0[v]; + T[3][i][v] = t3[v] * 1.5f + t2[v]; + T[4][i][v] = t3[v] * -1.5f + t2[v]; + T[5][i][v] = I[3][i][v] * -2.640625f + t5[v]; + } + } + +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][2][v] * -2.25f + T[i][4][v]; + t1[v] = T[i][1][v] * -2.25f + T[i][3][v]; + t2[v] = T[i][2][v] * -0.390625f + T[i][4][v]; + t3[v] = T[i][1][v] * -0.390625f + T[i][3][v]; + t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v]; + t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v]; + + Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v]; + Iw[i][1][v] = t1[v] * 0.625f + t0[v]; + Iw[i][2][v] = t1[v] * -0.625f + t0[v]; + Iw[i][3][v] = t3[v] * 1.5f + t2[v]; + Iw[i][4][v] = t3[v] * -1.5f + t2[v]; + Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v]; + } + } +} + +void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]) +{ + float T[6][4][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + +pragma_unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = F[2][i][v] * 0.26890756302521f; + t1[v] = F[0][i][v] * -0.688403361344538f - t0[v]; + t2[v] = F[0][i][v] * 0.119514472455649f + t0[v]; + t3[v] = F[1][i][v] * 0.430252100840336f + + F[3][i][v] * 0.168067226890756f; + t4[v] = F[1][i][v] * 0.179271708683473f + + F[3][i][v] * 0.403361344537815f; + + T[0][i][v] = F[0][i][v] * 1.13777777777778f; + T[1][i][v] = t1[v] - t3[v]; + T[2][i][v] = t1[v] + t3[v]; + T[3][i][v] = t2[v] + t4[v]; + T[4][i][v] = t2[v] - t4[v]; + T[5][i][v] = F[3][i][v]; + } + } +pragma_unroll + for (int i = 0; i < 6; i++) { + for (int v = 0; v < 16; v++) { + t0[v] = T[i][2][v] * 0.26890756302521f; + t1[v] = T[i][0][v] * -0.688403361344538f - t0[v]; + t2[v] = T[i][0][v] * 0.119514472455649f + t0[v]; + t3[v] = T[i][1][v] * 0.430252100840336f + + T[i][3][v] * 0.168067226890756f; + t4[v] = T[i][1][v] * 0.179271708683473f + + T[i][3][v] * 0.403361344537815f; + + Fw[i][0][v] = T[i][0][v] * 1.13777777777778f; + Fw[i][1][v] = t1[v] - t3[v]; + Fw[i][2][v] = t1[v] + t3[v]; + Fw[i][3][v] = t2[v] + t4[v]; + Fw[i][4][v] = t2[v] - t4[v]; + Fw[i][5][v] = T[i][3][v]; + } + } +} + +void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]) +{ + float T[3][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float M_[3][16]; + + for (int j = 0; j < 16; j++) { +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v]; + t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v]; + t2[v] = t1[v] * 2.25f + Mw[5][i][j][v]; + + T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v]; + T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) + + 1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]); + T[2][i][v] = t0[v] * 0.390625f + t2[v]; + } + } +pragma_unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][1][v] + T[i][2][v]; + t1[v] = T[i][3][v] + T[i][4][v]; + t2[v] = t1[v] * 2.25f + T[i][5][v]; + + M_[0][v] = T[i][0][v] + t0[v] + t1[v]; + M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) + + 1.5f * (T[i][3][v] - T[i][4][v]); + M_[2][v] = t0[v] * 0.390625f + t2[v]; + } + +pragma_unroll + for (int k = 0; k < 3; k++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + M[i][k][j][v] = M_[k][v]; + } + } + } + } +} + +template +void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp, bool streamout = true) +{ + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; + const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; + const int wp_max = inpw + l_pad; + const int hp_max = inph + t_pad; + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + + array_offset_calculator input(inp, + jcp.mb, jcp.dimK/simd_w, inph, inpw, + simd_w); + array_offset_calculator output(tinp, + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if ((t_pad <= ydim) && (ydim < hp_max)) { + float *pinp_j = inp + (ydim - t_pad) * inpw * 16 ; + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if ((l_pad <= xdim) && (xdim < wp_max)) { + float *pinp_i = pinp_j + (xdim - l_pad) * 16; + load_ps(I[j][i], pinp_i); + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + + trans_I_4x4_3x3(Iw, I); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(tile_block, j, i, + nb_tile_block_ur, 0, 0, + tile_block_ur, 0)), + Iw[j][i], streamout); + } + } + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) +{ + const int kh = 3; + const int kw = 3; + array_offset_calculator input(wp, + jcp.oc/jcp.oc_simd_block, + jcp.ic/jcp.ic_simd_block, + jcp.kh, jcp.kw, + simd_w, simd_w); + array_offset_calculator output(twp, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block, jcp.dimK_block, + simd_w, simd_w); + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + for (int v1 = 0; v1 < simd_w; v1++) { + float *base_inp = is_fwd + ? &(input(0, 0, j, i, v1, 0)) + : &(input(0, 0, 2 - j, 2 - i, v1, 0)); + PRAGMA_OMP_SIMD() + for (int v2 = 0; v2 < simd_w; v2++) { + if (is_fwd) + F[j][i][v1][v2] = *(base_inp + v2); + else + F[j][i][v2][v1] = *(base_inp + v2); + } + } + } + } + + trans_W_4x4_3x3(Fw, F); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int v1 = 0; v1 < simd_w; v1++) { + PRAGMA_OMP_SIMD() + for (int v2 = 0; v2 < simd_w; v2++) { + output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2]; + } + } + } + } +} + +template +void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias, + bool streamout = true) { + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + + /* Prepare for PostOps */ + bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1; + + array_offset_calculator input(toutp, + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Ow[j][i][v] = input(tile_block, 0, + j, i, + nb_tile_block_ur, 0, + tile_block_ur, v); + } + } + } + + trans_O_4x4_3x3(Ow, O); + + for (int j = 0; j < tile_size; j++) { + int ydim = tj * tile_size + j; + if (ydim < outh) { + float *pout_j = pout_b + ydim * outw * simd_w; + for (int i = 0; i < tile_size; i++) { + int xdim = ti * tile_size + i; + if (xdim < outw) { + float *pout_i = pout_j + xdim * simd_w; + if (is_fwd) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + O[j][i][v] += with_bias ? bias[v] : 0.f; + O[j][i][v] = true + && with_relu_presum && O[j][i][v] < 0.f + ? O[j][i][v] + * jcp.eltwise.alpha + : O[j][i][v]; + } + } + if (with_sum) + accum_output(pout_i, O[j][i], streamout, + with_relu_postsum); + else + store_output(pout_i, O[j][i], streamout); + } + } + } + } + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, + float *inp, float *tinp, float *Iw_temp, + void (*transpose_4fma_ker)(float *, float *)) +{ + + const int ifwp = conv.iw + conv.l_pad; + const int ifhp = conv.ih + conv.t_pad; + float I[alpha][alpha][simd_w]; + float Iw[alpha][alpha][simd_w]; + + array_offset_calculator Iw_trans_temp(Iw_temp, + alpha, alpha, conv.tile_4fma, simd_w); + array_offset_calculator input(inp, + conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w); + array_offset_calculator output(tinp, + conv.nb_ic, alpha, alpha, + conv.tile_block, conv.ic_block, + conv.nb_tile_block_ur, conv.tile_block_ur, + conv.ic_simd_block * conv.tile_4fma); + + int tile_base_index = + image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding); + int tile_4fma = 0; + int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / conv.tile_4fma / conv.tile_block_ur) + % conv.nb_tile_block_ur; + int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur) + / conv.nb_tile_block_ur; + + for (int tj = 0; tj < conv.jtiles; tj++) { + for (int ti = 0; ti < conv.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if ((conv.t_pad <= ydim) && ydim < ifhp) { + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if ((conv.l_pad <= xdim) && xdim < ifwp) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = input(0, 0, + ydim - conv.t_pad, + xdim - conv.l_pad, v); + } + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + trans_I_4x4_3x3(Iw, I); + + if (ver_4fma) { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + float *Iw_temp_base = &(Iw_trans_temp(j, i, + tile_4fma, 0)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Iw_temp_base[v] = Iw[j][i][v]; + } + } + } + tile_4fma++; + if (tile_4fma == conv.tile_4fma) { + float *outp = &(output(0, 0, 0, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)); + transpose_4fma_ker(outp, (float *)Iw_temp); + tile_4fma = 0; + tile_block_ur++; + } + } else { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(0, j, i, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)), + Iw[j][i], true); + } + } + tile_block_ur++; + } + + if (tile_block_ur == conv.tile_block_ur) { + tile_block_ur = 0; + ++nb_tile_block_ur; + } + if (nb_tile_block_ur == conv.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } + + if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) { + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) { + float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Iw_temp_base[v] = 0; + } + } + } + } + float *outp = &(output(0, 0, 0, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)); + transpose_4fma_ker(outp, (float *)Iw_temp); + } +} + +template +void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, + float *inp, float *tinp, float *dbias) +{ + + const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding; + float I[alpha][alpha][simd_w]; + float Iw[alpha][alpha][simd_w]; + + array_offset_calculator input(inp, + conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block); + array_offset_calculator output(tinp, + conv.nb_oc, alpha, alpha, + conv.tile_block, conv.oc_block, + conv.nb_tile_block_ur, + conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block); + + int tile_base_index = image * total_tiles; + int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma); + int nb_tile_block_ur = + (tile_base_index / conv.tile_block_ur / conv.tile_4fma) + % conv.nb_tile_block_ur; + int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma) + / conv.nb_tile_block_ur; + + for (int tj = 0; tj < conv.jtiles; tj++) { + for (int ti = 0; ti < conv.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if (ydim < conv.oh) { + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if (xdim < conv.ow) { + float *input_base = &(input(0, 0, ydim, xdim, 0)); + + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = input_base[v]; + } + if (with_bias && j < tile_size && i < tile_size) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + dbias[v] += input_base[v]; + } + } + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + + trans_W_3x3_4x4_wu(Iw, I); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(0, j, i, + tile_block, 0, + nb_tile_block_ur, + tile_block_ur, 0)), + Iw[j][i], true); + } + } + tile_block_ur++; + if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= conv.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv, + float *wp, float *twp) +{ + const int kh = 3; + const int kw = 3; + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + + array_offset_calculator input(twp, + conv.nb_ic, conv.nb_oc, + alpha, alpha, + conv.oc_block, conv.ic_block, + conv.ic_simd_block, conv.oc_simd_block); + array_offset_calculator output(wp, + conv.oc/simd_w, conv.ic/simd_w, + conv.kh, conv.kw, + conv.ic_simd_block, conv.oc_simd_block); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int v = 0; v < conv.ic_simd_block; v++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < conv.oc_simd_block; k++) { + Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k); + } + } + } + } + + trans_O_3x3_4x4_wu(Fw, F); + + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + for (int v = 0; v < conv.ic_simd_block; v++) { + store_output(&(output(0, 0, j, i, v, 0)), + F[j][i][v], true); + } + } + } +} + +template +void _jit_avx512_common_convolution_winograd_t::_execute_data_W_S_G_D( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + /* Note that jcp.with_eltwise is true for both fused conv+relu primitive + * and conv primitive with PostOps with relu before sum + * (PostOps relu after sum is handled later) */ + auto output_transform = jcp.with_bias + ? (jcp.with_eltwise + ? (jcp.with_sum + ? output_transform_data + : output_transform_data) + : (jcp.with_sum + ? output_transform_data + : output_transform_data)) + : (jcp.with_eltwise + ? (jcp.with_sum + ? output_transform_data + : output_transform_data) + : (jcp.with_sum + ? output_transform_data + : output_transform_data)); + + /* Notation: + FWD: dimM:oc, dimN:ntiles, dimK:ic, + BWD: dimM:ic, dimN:ntiles, dimK:oc, + FWD/BWD: V: src/diff_dst transform, U:weight transform, + M:dst/diff_src transform */ + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, + jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, + jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + array_offset_calculator U( + scratchpad.template get(key_wino_U), + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); + + bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) + > 2 * LLC_cache_size ? true : false; + + const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0; + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + { + parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + [&](int img, int K_blk1, int K_blk2) { + input_transform_data(img, jcp, + &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout); + }); + + parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block, + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights(ofm1 * jcp.oc_block + ofm2, + ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr); + }); + + parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block, + [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) { + + kernel_->gemm_loop_ker_first_iter( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + 0, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, 0, 0, 0, 0))); + for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) { + kernel_->gemm_loop_ker( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + K_blk1, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, K_blk1, + 0, 0, 0))); + } + + }); + + parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block, + [&](int img, int M_blk1, int M_blk2) { + + const int M_blk = M_blk1 * jcp.dimM_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + + output_transform(img, jcp, p_ops, + &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(img, M_blk, 0, 0, 0)), + bias_ptr, output_is_aligned); + + }); + + } +} + +template struct _jit_avx512_common_convolution_winograd_t; +template struct _jit_avx512_common_convolution_winograd_t; + +void jit_avx512_common_convolution_winograd_bwd_weights_t:: +_maybe_execute_diff_bias_copy(float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) + diff_bias[oc] = padded_bias[oc]; + } +} + +void jit_avx512_common_convolution_winograd_bwd_weights_t:: +_execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, + const memory_tracking::grantor_t &scratchpad) const { + auto ptr_diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto ptr_src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto ptr_diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); + auto ptr_diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ? + diff_src_transform_bwd_weights : + diff_src_transform_bwd_weights; + auto diff_dst_transform_bwd_weights_ver = jcp.with_bias + ? diff_dst_transform_bwd_weights + : diff_dst_transform_bwd_weights; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights(ptr_diff_weights, + jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + array_offset_calculator diff_bias(pd()->wants_padded_bias() + ? scratchpad.get(key_conv_padded_bias) : ptr_diff_bias, + jcp.oc/simd_w, simd_w); + + array_offset_calculator U( + scratchpad.get(key_wino_U), + jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, jcp.oc_simd_block); + + array_offset_calculator M( + scratchpad.get(key_wino_M), + jcp.nb_oc, alpha, alpha, + jcp.tile_block, jcp.oc_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma, + jcp.oc_simd_block); + array_offset_calculator V( + scratchpad.get(key_wino_V), + jcp.nb_ic, alpha, alpha, + jcp.tile_block, jcp.ic_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur, + jcp.ic_simd_block * jcp.tile_4fma); + + const int trans_buffer_size = alpha * alpha * jcp.tile_4fma + * jcp.ic_simd_block; + array_offset_calculator trans_buffer( + scratchpad.get(key_conv_tr_src), + nthreads, + trans_buffer_size); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), + nthreads, + jcp.oc); + +PRAGMA_OMP(parallel num_threads(nthreads)) + { + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { + diff_bias_prv(ithr, ofm) = 0.0f; + }); + +PRAGMA_OMP(for nowait) + for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + diff_bias(bofm, v) = 0.0f; + } + } + + const int ithread = mkldnn_get_thread_num(); + + parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block, + [&](int img, int ifm1, int ifm2) { + float *transb = jcp.ver == ver_4fma + ? &(trans_buffer(ithread, 0)) + : NULL; + diff_src_transform_bwd_weights_ver(img, jcp, + &(src(img, ifm1 * jcp.ic_block + ifm2, + 0, 0, 0)), + &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)), + transb, + kernel_->transpose_4fma_ker); + }); + + parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block, + [&](int img, int ofm1, int ofm2) { + float *dbias = jcp.with_bias + ? &(diff_bias_prv(ithread, + simd_w * (ofm1 * jcp.oc_block + ofm2))) + : NULL; + diff_dst_transform_bwd_weights_ver(img, jcp, + &(diff_dst(img, ofm1 * jcp.oc_block + ofm2, + 0, 0, 0)), + &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)), + dbias); + }); + +PRAGMA_OMP(barrier) + + for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) { + parallel_nd_in_omp(alpha, alpha, jcp.nb_oc, + [&](int oj, int oi, int ofm1) { + kernel_->gemm_loop_ker_first_iter( + (float *)&(U(ifm1, ofm1, oj, oi, + 0, 0, 0, 0)), + (const float *)&(M(ofm1, oj, oi, + 0, 0, 0, 0, 0)), + (const float *)&(V(ifm1, oj, oi, + 0, 0, 0, 0, 0))); + for (int tile_block = 1; tile_block < jcp.tile_block; + tile_block++) { + kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1, + oj, oi, + 0, 0, 0, 0)), + (const float *)&(M(ofm1, oj, oi, tile_block, + 0, 0, 0, 0)), + (const float *)&(V(ifm1, oj, oi, tile_block, + 0, 0, 0, 0))); + } + }); + } + +PRAGMA_OMP(barrier) + + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, + [&](int ifm1, int ofm1, int ofm2, int ifm2) { + diff_weights_transform_bwd_weights(jcp, + &(diff_weights(ofm1 * jcp.oc_block + ofm2, + ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), + &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0))); + }); + + if (jcp.with_bias) { +PRAGMA_OMP(for) + for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) { + for (int ithr = 0; ithr < nthreads; ithr++) { + float* base_bias_ptr = &(diff_bias(ofm1, 0)); + float* base_bias_prv_ptr = &(diff_bias_prv( + ithr * jcp.oc + ofm1 * simd_w)); + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < simd_w; ofm2++) { + base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2]; + } + } + } + } + } + + _maybe_execute_diff_bias_copy(ptr_diff_bias, scratchpad); +} + +} +} +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp new file mode 100644 index 0000000000..6c76f37c72 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp @@ -0,0 +1,318 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP +#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace winograd_avx512_common { +inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_winograd_conf_t &jcp) { + using namespace memory_tracking::names; + + size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; + size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic + * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc + * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + + scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); + + if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) { + const int nthr = mkldnn_get_max_threads(); + + size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr + * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block; + scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M); + + size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0; + scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); + + size_t padded_bias_sz = + jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0; + scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz); + } +} +} + +template +struct _jit_avx512_common_convolution_winograd_t { + _jit_avx512_common_convolution_winograd_t( + const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) + : kernel_(nullptr), attr_(attr) { + kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp); + } + + ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; } + + protected: + void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_; + const primitive_attr_t *attr_; +}; + +struct jit_avx512_common_convolution_winograd_fwd_t + : _jit_avx512_common_convolution_winograd_t + , public cpu_primitive_t + { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), + *attr()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd) + : _jit_avx512_common_convolution_winograd_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) {} + + ~jit_avx512_common_convolution_winograd_fwd_t(){}; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, + (float *)bias, this->scratchpad(ctx)); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_common_convolution_winograd_bwd_data_t + : _jit_avx512_common_convolution_winograd_t, + public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && !has_zero_dim_memory() + && set_default_formats() + && mkldnn_thr_syncable(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( + jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd) + : _jit_avx512_common_convolution_winograd_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) {} + + ~jit_avx512_common_convolution_winograd_bwd_data_t(){}; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, + (float *)weights, nullptr, this->scratchpad(ctx)); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_common_convolution_winograd_bwd_weights_t + : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, + hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats() + && mkldnn_thr_syncable(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), + *diff_weights_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true), kernel_(nullptr) + { + kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( + pd()->jcp_); + } + + ~jit_avx512_common_convolution_winograd_bwd_weights_t() + { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx)); + return status::success; + } + +private: + void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, + const memory_tracking::grantor_t &scratchpad) const; + void _maybe_execute_diff_bias_copy(float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_; +}; + +void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]); +void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]); +void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]); +void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]); +void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]); +void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]); +void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]); + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp new file mode 100644 index 0000000000..d4a451c021 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp @@ -0,0 +1,853 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_lrn.hpp" + +#include "jit_generator.hpp" + +#define FWD_RBC 4 +#define BWD_RBC 3 + +#define XMM_SIZE (4*sizeof(float)) +#define ZMM_SIZE (vlen) +#define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE) +#define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE) +#define SRC_PREV_OFFSET (vlen - XMM_SIZE) + +#define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \ + statement;\ +} + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +enum params { vsize = 16, vlen = 64}; + +typedef struct { + const float *src; + float *dst, *ws0, *ws1; +} jit_args_fwd_t; + +typedef struct { + const float *src, *diff_dst, *ws0, *ws1; + float *diff_src; +} jit_args_bwd_t; + +struct nChw16c_across { +/* version: + * -1: channels 0..15, + * 1: channels C-16 .. C-1, + * 0: other channels + * 3: channels only for this kernel(without prev and next) + */ + int H, W, version; + nChw16c_across(int h, int w, int v) : H(h), W(w), version(v) {} +}; + +struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32: + public jit_generator { + int HW, W; + bool is_first; + bool is_last; + bool is_single; + + Reg64 src = rax; + Reg64 dst = r8; + Reg64 scratch0 = rdx; + Reg64 scratch1 = rsi; + Reg64 imm_addr64 = rbx; + + Zmm zalpha = zmm0; + Xmm xalpha = xmm0; + Zmm zk = zmm1; + Xmm xk = xmm1; + + Reg64 param = abi_param1; + Reg64 t = rsp; + Reg64 hw = r9; + + int xsrc_prev = 2; + int zsrc = 7; + int xsrc_next = 3; + int zc = 7; + + int za = 2; + int zb = 3; + int zd = 5; + int ze = 6; + int zsum = 4; + int zdst = 2; + int zbase = 3; + int zsum2 = 5; + + prop_kind_t pk; + int use_h_parallelism; + + float alpha, k; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) + + void (*ker)(jit_args_fwd_t *); + void operator()(jit_args_fwd_t *arg) { ker(arg); } + + enum { + prf0_offt = 1*FWD_RBC, + prf2_offt = 8*FWD_RBC + }; + + inline void compute_loop(int loop_size_param) + { + // loop_size - param for IRB_LOOP macro + int loop_size = FWD_RBC; + + auto xreg = [=](int irb, int i) { + return Xmm(irb*3 + i); + }; + + auto zreg = [=](int irb, int i) { + return Zmm(irb*7 + i); + }; + + if (!is_first && !is_single) { + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen])); + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen])); + } + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen))); + if (!is_last && !is_single) { + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen])); + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen])); + } + if (pk != prop_kind::forward_inference) { + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0, + (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0, + (irb + prf2_offt)*vlen))); + } + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen))); + if (pk != prop_kind::forward_inference) { + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1, + (irb + prf0_offt) * vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1, + (irb + prf2_offt) * vlen))); + } + + loop_size = loop_size_param; + if (loop_size == 0) + return; + if (!is_first && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xsrc_prev), + ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET])); + } + IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xsrc_next), + ptr[src + (irb + HW) * vlen])); + } + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], + xreg(irb, xsrc_prev))); + } + IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), + zreg(irb, zsrc))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xreg(irb, xsrc_next))); + } + + IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 2*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 2*sizeof(float)))); + + assert(zc == zsrc); + IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc))); + + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze))); + + IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha)); + + IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum))); + + IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum))); + IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2))); + + IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); + IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); + + if (pk != prop_kind::forward_inference) { + IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen), + zreg(irb, zsum))); + } + IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum))); + IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst))); + if (pk != prop_kind::forward_inference) { + /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */ + IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase))); + IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen), + zreg(irb, zsum))); + } + } + + jit_avx512_common_lrn_kernel_f32( + const struct nChw16c_across &J, + prop_kind_t prop_kind, + int use_h_parallel, + float A, + float K, + void *code_ptr = nullptr, + size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + , pk(prop_kind) + , use_h_parallelism(use_h_parallel) + , alpha(A) + , k(K) + { + this->preamble(); + + mov(src, ptr[param + 0]); + mov(dst, ptr[param + 8]); + if (pk != prop_kind::forward_inference) + { + mov(scratch0, ptr[param + 16]); + mov(scratch1, ptr[param + 24]); + } + is_first = J.version == -1 || J.version == -2; + is_last = J.version == +1 || J.version == -2; + is_single = J.version == 3; + + W = J.W; + HW = J.W*J.H; + int LSB = use_h_parallelism ? W : HW; + + sub(t, FWD_RBC*BUFFER_BLOCK); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(zalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(zk, xk); + + if (is_first || is_single) { + vxorps(xmm2, xmm2, xmm2); + for(int irb = 0; irb < FWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2); + } + } + if (is_last || is_single) { + vxorps(xmm2, xmm2, xmm2); + for(int irb = 0; irb < FWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xmm2); + } + } + + int LSREST = LSB % FWD_RBC; + int LS = LSB - LSREST; + + Label lrn_loop; + + if (LS > 0) { + mov(hw, LS); + + L(lrn_loop); + { + compute_loop(FWD_RBC); + + add(src, FWD_RBC*vlen); + add(dst, FWD_RBC*vlen); + if (pk != prop_kind::forward_inference) + { + add(scratch0, FWD_RBC*vlen); + add(scratch1, FWD_RBC*vlen); + } + + for(int irb = 0; irb < FWD_RBC; irb++) + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + } + + compute_loop(LSREST); + + add(t, FWD_RBC*BUFFER_BLOCK); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +status_t jit_avx512_common_lrn_fwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(avx512_common) + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, data_d.data_type()) + && data_d.ndims() == 4 + && data_d.dims()[1] % vsize == 0 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + if (desc()->prop_kind == forward_training) { + dims_t ws_dims = { MB(), C(), H(), 2*W() }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, + format_tag::nChw16c); + } + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && desc()->lrn_beta == 0.75 + && data_d.matches_tag(format_tag::nChw16c); + + return args_ok_across ? success : unimplemented; +} + +jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) + , ker_last_(nullptr) { + using namespace alg_kind; + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + const float alpha = pd()->desc()->lrn_alpha / ls; + const float k = pd()->desc()->lrn_k; + + auto pk = pd()->desc()->prop_kind; + + use_h_parallelism = H > 28 ? 1 : 0; + + if (C / vsize == 1) { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk, + use_h_parallelism, alpha, k); + } else { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk, + use_h_parallelism, alpha, k); + ker_first_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k); + ker_last_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k); + } +} + +jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +void jit_avx512_common_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + const int C16 = C / vsize; + const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; + + balance211(work_amount, nthr, ithr, start, end); + if (use_h_parallelism) { + int n{0}, c16{0}, h{0}; + nd_iterator_init(start, n, N, c16, C16, h, H); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize + + h*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize + + h*2*W*vsize; + auto ws_offset1 = ws_offset0 + W*vsize; + + jit_args_fwd_t args; + args.src = &src[offset]; + args.dst = &dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + nd_iterator_step(n, N, c16, C16, h, H); + } + } else { + int n{0}, c16{0}; + nd_iterator_init(start, n, N, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; + auto ws_offset1 = ws_offset0 + H*W*vsize; + + jit_args_fwd_t args; + args.src = &src[offset]; + args.dst = &dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + + nd_iterator_step(n, N, c16, C16); + } + } + }); +} + +struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32: + public jit_generator { + int HW, W; + bool is_first; + bool is_last; + bool is_single; + + Reg64 src = rax; + Reg64 diffsrc = r8; + Reg64 diffdst = r9; + Reg64 workspace0 = rdx; + Reg64 workspace1 = rsi; + Reg64 imm_addr64 = rbx; + + Zmm znalphabeta = zmm0; + Xmm xnalphabeta = xmm0; + + Reg64 param = abi_param1; + Reg64 t = rsp; + Reg64 hw = r10; + + int xws1_prev = 1; + int xdiffdst_prev = 2; + int zws1 = 1; + + int zsrc = 1; + int zdiffdst = 5; + int zdiffsrc = 6; + + int xws1_next = 1; + int xdiffdst_next = 3; + + int za = 1; + int zb = 2; + int zd = 3; + int ze = 4; + int zws0 = 2; + + float nalphabeta; + + int use_h_parallelism; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) + + void (*ker)(jit_args_bwd_t *); + void operator()(jit_args_bwd_t *arg) { ker(arg); } + + enum { + prf0_offt = 1*BWD_RBC, + prf2_offt = 8*BWD_RBC + }; + + inline void compute_loop(int loop_size_param, int prefetchL1, + int prefetchL2) + { + // loop_size - param for IRB_LOOP macro + int loop_size = loop_size_param; + + auto xreg = [=](int irb, int i) { + return Xmm(irb*6 + i); + }; + + auto zreg = [=](int irb, int i) { + return Zmm(irb*6 + i); + }; + +// ---- prefetching ------------------------------------------- + if (!is_first && !is_single) { + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt + - 2 * HW) * vlen])); + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt + - HW) * vlen])); + } + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen])); + + if (!is_last && !is_single) { + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt + + 2 * HW) * vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt + + 2 * HW) * vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt + + HW) * vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[diffdst + (irb + prf2_offt + + HW) * vlen])); + } + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen])); +// ----------------------------------------------------------- + + if (loop_size_param == 0) + return; + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb + - 2 * HW) * vlen + SRC_PREV_OFFSET])); + IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb + - HW) * vlen + SRC_PREV_OFFSET])); + IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev), + xreg(irb, xws1_prev))); + } + + IRB_LOOP(vmovups(zreg(irb, zws1), + EVEX_compress_addr(workspace1, irb*vlen))); + IRB_LOOP(vmovups(zreg(irb, zdiffdst), + EVEX_compress_addr(diffdst, irb*vlen))); + IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst), + zreg(irb, zws1))); + + if (!is_last && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb + + 2 * HW) * vlen])); + IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst + (irb + + HW) * vlen])); + IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next), + xreg(irb, xws1_next))); + } + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], + xreg(irb, xdiffdst_prev))); + } + IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), + zreg(irb, zdiffsrc))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xreg(irb, xdiffdst_next))); + } + + IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 2*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 1*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 1*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 2*sizeof(float)))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, za))); + assert(zsrc == za); + IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, zb))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, zd))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, ze))); + IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta)); + + IRB_LOOP(vmovups(zreg(irb, zws0), + EVEX_compress_addr(workspace0, irb*vlen))); + IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst), + zreg(irb, zws0))); + IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc), + zreg(irb, zdiffdst))); + + Label unaligned_store, end_store; + test(diffsrc, vlen - 1); + jnz(unaligned_store, T_NEAR); + IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen), + zreg(irb, zdiffsrc))); + jmp(end_store, T_NEAR); + L(unaligned_store); { + IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen), + zreg(irb, zdiffsrc))); + } + L(end_store); + } + + jit_avx512_common_lrn_kernel_f32( + const struct nChw16c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + , nalphabeta(-2*A*B) + , use_h_parallelism(use_h_parallel) + { + this->preamble(); + + mov(src, ptr[param + 0]); + mov(diffdst, ptr[param + 8]); + mov(workspace0, ptr[param + 16]); + mov(workspace1, ptr[param + 24]); + mov(diffsrc, ptr[param + 32]); + + W = J.W; + HW = J.H*J.W; + int LSB = this->use_h_parallelism ? W : HW; + + sub(t, BWD_RBC*BUFFER_BLOCK); + mov(imm_addr64, float2int(this->nalphabeta)); + movq(xnalphabeta, imm_addr64); + vbroadcastss(znalphabeta, xnalphabeta); + + is_first = J.version == -1 || J.version == -2; + is_last = J.version == +1 || J.version == +2; + is_single = J.version == 3; + + if (is_first || is_single) { + vxorps(xmm1, xmm1, xmm1); + for(int irb = 0; irb < BWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1); + } + } + if (is_last || is_single) { + vxorps(xmm1, xmm1, xmm1); + for(int irb = 0; irb < BWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1); + } + } + + int LSREST = LSB % BWD_RBC; + int LS = LSB - LSREST; + + Label lrn_loop; + + if (LS > 0) { + mov(hw, LS); + + L(lrn_loop); + { + compute_loop(BWD_RBC, 1, 1); + + add(src, BWD_RBC*vlen); + add(diffsrc, BWD_RBC*vlen); + add(diffdst, BWD_RBC*vlen); + add(workspace0, BWD_RBC*vlen); + add(workspace1, BWD_RBC*vlen); + + for(int irb = 0; irb < BWD_RBC; irb++) + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + } + + compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1); + + add(t, BWD_RBC*BUFFER_BLOCK); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } + +}; + +status_t jit_avx512_common_lrn_bwd_t::pd_t::init() { + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(avx512_common) + && !is_fwd() + && utils::everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % vsize == 0 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + dims_t ws_dims = { MB(), C(), H(), 2*W() }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, + format_tag::nChw16c); + + if (!compare_ws(hint_fwd_pd_)) return unimplemented; + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && desc()->lrn_beta == 0.75 + && data_d.matches_tag(format_tag::nChw16c); + + return args_ok_across ? success : unimplemented; +} + +jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) + , ker_last_(nullptr) { + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + const float alpha = pd()->desc()->lrn_alpha / ls; + const float beta = pd()->desc()->lrn_beta; + + use_h_parallelism = H > 28 ? 1 : 0; + + if (C / vsize == 1) { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), + alpha, beta, use_h_parallelism); + } else { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), + alpha, beta, use_h_parallelism); + ker_first_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism); + ker_last_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism); + } +} + +jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +void jit_avx512_common_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + const int C16 = C / vsize; + const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; + + balance211(work_amount, nthr, ithr, start, end); + if (use_h_parallelism) { + int n{0}, c16{0}, h{0}; + nd_iterator_init(start, n, N, h, H, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize + + h*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize + + h*2*W*vsize; + auto ws_offset1 = ws_offset0 + W*vsize; + + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + args.diff_src = &diff_src[offset]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + nd_iterator_step(n, N, h, H, c16, C16); + } + } else { + int n{0}, c16{0}; + nd_iterator_init(start, n, N, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; + auto ws_offset1 = ws_offset0 + H*W*vsize; + + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + args.diff_src = &diff_src[offset]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + + nd_iterator_step(n, N, c16, C16); + } + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp new file mode 100644 index 0000000000..37fbb9b3e5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_LRN_HPP +#define CPU_JIT_AVX512_COMMON_LRN_HPP + +#include "c_types_map.hpp" + +#include "cpu_isa_traits.hpp" +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_common_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_lrn_fwd_t); + + status_t init(); + }; + + jit_avx512_common_lrn_fwd_t(const pd_t *apd); + ~jit_avx512_common_lrn_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int use_h_parallelism; + struct jit_avx512_common_lrn_kernel_f32; + jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +struct jit_avx512_common_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_lrn_bwd_t); + + status_t init(); + }; + + jit_avx512_common_lrn_bwd_t(const pd_t *apd); + ~jit_avx512_common_lrn_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int use_h_parallelism; + struct jit_avx512_common_lrn_kernel_f32; + jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp new file mode 100644 index 0000000000..c58d3fa0a6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp @@ -0,0 +1,1103 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_fp32_wino_conv_2x3.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_kind; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + + struct call_params_t { + const void *src; + const void *wino_src; + const void *v_y_masks; + const void *v_x_masks; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp) { + generate(); + ker_ = + reinterpret_cast(const_cast(getCode())); + } + + void generate(); + + Zmm vreg_inp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Zmm vreg_tmp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(15 - i); + } + + Zmm vreg_out(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert (id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_ic_block = r8; + +}; + +void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() { + Label ic_block_label; + + const int load_block = 16; + int out_offset = 0, inp_offset = 0; + preamble(); + +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_aux_ptr_src, src); + READ_PARAM(reg_aux_ptr_dst, wino_src); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); +#undef READ_PARAM + + for (int i = 0; i < jcp.alpha; i++) { + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); + } + mov(reg_ic_block, jcp.ic / load_block); + L(ic_block_label); + { + for (int y = 0; y < jcp.alpha; y++) { + kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]); + for (int x = 0; x < jcp.alpha; x++) { + Zmm zmm = vreg_inp(y * jcp.alpha + x); + + vxorps(zmm, zmm, zmm); + kandw(r_mask, y_mask, x_mask(x)); + inp_offset = sizeof(float) + * ((-jcp.t_pad + y) * jcp.iw * load_block + + (-jcp.l_pad + x) * load_block); + vmovups(zmm | r_mask, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } + } + for (int y = 0; y < jcp.alpha; y++) { + vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0), + vreg_inp(y * jcp.alpha + 2)); + vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1), + vreg_inp(y * jcp.alpha + 2)); + vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2), + vreg_inp(y * jcp.alpha + 1)); + vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1), + vreg_inp(y * jcp.alpha + 3)); + } + for (int x = 0; x < jcp.alpha; x++) { + vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0), + vreg_tmp(x + jcp.alpha * 2)); + vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), + vreg_tmp(x + jcp.alpha * 2)); + vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2), + vreg_tmp(x + jcp.alpha * 1)); + vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), + vreg_tmp(x + jcp.alpha * 3)); + } + + for (int i = 0; i < 16; i++) { + out_offset = sizeof(float) * (jcp.inp_stride * i); + vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), + vreg_out(i)); + } + + add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block); + add(reg_aux_ptr_dst, sizeof(float) * load_block); + } + dec(reg_ic_block); + cmp(reg_ic_block, 0); + jg(ic_block_label, T_NEAR); + postamble(); +} + +/// DST TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *wino_dst; + const void *dst; + const void *v_y_masks; + const void *v_x_masks; + + const void *bias; + const void *scales; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) { + generate(); + ker_ = reinterpret_cast( + const_cast(getCode())); + } + + void generate(); + bool maybe_relu(int position); + + Zmm vreg_inp(int i) { // 16 + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Zmm vreg_stg(int id) { // 8 + const int id_reg_stg = jcp.alpha * jcp.alpha + id; + assert(id_reg_stg < jcp.alpha * jcp.alpha + 8); + return Zmm(31 - id_reg_stg); + } + + Zmm vreg_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id_reg_out < jcp.alpha * jcp.alpha + 12); + return Zmm(31 - id_reg_out); + } + + Zmm vreg_tmp(int id) { // 2 + const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; + assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14); + return Zmm(31 - id_reg_tmp); + } + + Zmm vreg_zero = Zmm(0); + Zmm vreg_prev_dst = Zmm(0); + Zmm vreg_bias = Zmm(2); + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert (id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_oc_block = r8; + + Reg64 reg_ptr_bias = rbx; + Reg64 reg_ptr_scales = abi_not_param1; + Reg64 reg_ptr_sum_scale = rdx; +}; + +bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* relu before sum */ + return false + || p.contain(eltwise, 0); + } else if (position == 1) { + /* relu after sum */ + const int sum_idx = p.contain(sum, 0) + ? 0 : (p.contain(sum, 1) ? 1 : -1); + if (sum_idx == -1) + return false; + + return false + || p.contain(eltwise, sum_idx + 1); + } + + return false; +} + +void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() { + Label oc_block_label; + + const int load_block = 16; + + auto loop_body = [=]() { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + for (int i = 0; i < 16; i++) { + int internal_offset = sizeof(float) * jcp.out_stride * i; + vmovups(vreg_inp(i), + EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); + } + for (int y = 0; y < jcp.alpha; y++) { + vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1)); + vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2)); + + vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2)); + vsubps(vreg_stg(y * 2+1), vreg_tmp(1), vreg_inp(y * 4 + 3)); + } + for (int x = 0; x < jcp.m; x++) { + vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2 * 1)); + vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x+2 * 2)); + + vsubps(vreg_tmp(1), vreg_stg(x+2 * 1), vreg_stg(x+2 * 2)); + vsubps(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2 * 3)); + } + + + if (jcp.with_bias) { + auto bias_addr = ptr [ reg_ptr_bias ]; + vmovups(vreg_bias, bias_addr); + } + for (int y = 0; y < jcp.m; y++) { + kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]); + for (int x = 0; x < jcp.m; x++) { + kandw(r_mask, y_mask, x_mask(x)); + + int i = y * jcp.m + x; + int offset = sizeof(float) * + (y * jcp.ow * jcp.oc_block + x * jcp.oc_block); + Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); + + Zmm zmm = vreg_out(i); + if (jcp.with_bias) + vaddps(zmm, zmm, vreg_bias); + vmulps(zmm, zmm, ptr [reg_ptr_scales]); + + if (maybe_relu(0)) { + vxorps(vreg_zero, vreg_zero, vreg_zero); + vmaxps(zmm, vreg_zero, zmm); + } + if (p_sum_scale) { // post_op: sum + vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); + vmovups(vreg_prev_dst | r_mask, addr); + if (*p_sum_scale == 1.f) + vaddps(zmm, vreg_prev_dst); + else + vfmadd231ps(zmm, vreg_prev_dst, + zword_b[reg_ptr_sum_scale]); + } + if (maybe_relu(1)) { + vxorps(vreg_zero, vreg_zero, vreg_zero); + vmaxps(zmm, vreg_zero, zmm); + } + + vmovups(addr, zmm | r_mask); + } + } + }; + + preamble(); + +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_aux_ptr_src, wino_dst); + READ_PARAM(reg_aux_ptr_dst, dst); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); + READ_PARAM(reg_ptr_bias, bias); + READ_PARAM(reg_ptr_scales, scales); +#undef READ_PARAM + + for (int i = 0; i < jcp.alpha * jcp.alpha; i++) + vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i)); + + for (int i = 0; i < jcp.alpha; i++) + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); + + int oc_blocks = 1; + oc_blocks = jcp.oc / load_block; + mov(reg_oc_block, oc_blocks); + L(oc_block_label); + { + loop_body(); + add(reg_aux_ptr_src, sizeof(float) * load_block); + add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block); + + add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + add(reg_ptr_bias, jcp.typesize_bia * load_block); + } + dec(reg_oc_block); + cmp(reg_oc_block, 0); + jg(oc_block_label, T_NEAR); + + sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block); + + postamble(); + +} + +/// GEMM kernel //////////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t) + jit_conv_conf_2x3_wino_t jcp; + + struct call_params_t { + const void *src; + const void *dst; + const void *wei; + const void *dst_b; + }; + void (*ker_)(const call_params_t *); + + void generate(); + static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, + const primitive_attr_t &attr); + + jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp) { + generate(); + ker_ = reinterpret_cast( + const_cast(getCode())); + } + + static status_t init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, + memory_desc_t& expect_wei_md); + + Zmm vreg_out(int n, int m) { + const int id_reg_out = n * jcp.m_block + m; + assert(id_reg_out < jcp.n2_block * jcp.m_block); + return Zmm(31 - id_reg_out); + } + Zmm vreg_wei(int i) { + assert (31 - jcp.n2_block * jcp.m_block - i > 1); + return Zmm(31 - jcp.n2_block * jcp.m_block - i); + } + + Zmm vreg_src = Zmm(0); + Zmm vreg_one = Zmm(1); + Zmm vreg_tmp = Zmm(2); + + Reg64 reg_ptr_src = r15; + + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; + Reg64 reg_aux_wei = r10; + Reg64 reg_aux_wei2 = r9; + Reg64 reg_aux_src = r8; + Reg64 reg_aux_src2 = rax; + + Reg64 reg_mb = rbx; + Reg64 reg_nnb = rdx; + Reg64 reg_K = rsi; + +}; + +bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok( + jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_relu(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_relu(1)) || + (p.contain(sum, 1) && is_relu(0)); + case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); + default: return false; + } + + return false; +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() { + Label nnb_loop_label, K_loop_label, mb_loop_label; + + preamble(); +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); +#undef READ_PARAM + + if (!jcp.small_mb) { + mov(reg_nnb, jcp.n_chunks); + L(nnb_loop_label); + } + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + int nb2 = 0; + for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m)); + } + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + + mov(reg_K, jcp.k_chunks); + L(K_loop_label); { + int wei_offset = 0; + for (int _i = 0; _i < jcp.k2_block; _i++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + if (jcp.small_mb) { + int wei_offset = sizeof(float) + * ((nb2 * jcp.nb_ic * jcp.ic_block + * jcp.oc_block) + + _i * jcp.oc_block); + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, wei_offset)); + } else { + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, + sizeof(float) * wei_offset)); + wei_offset += jcp.oc_block; + } + } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = sizeof(float) * (m * jcp.K + _i); + if (jcp.n2_block > 1) { + vbroadcastss(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2), + vreg_src); + } else { + vfmadd231ps(vreg_out(0, m), vreg_wei(0), + EVEX_compress_addr(reg_aux_src2, inp_offset, true)); + } + } + } + add(reg_aux_src2, sizeof(float) * jcp.ic_block); + if (jcp.small_mb) + add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block); + else + add(reg_aux_wei2, + sizeof(float) * jcp.k2_block * jcp.n2_block + * jcp.oc_block); + } + dec(reg_K); + cmp(reg_K, 0); + jg(K_loop_label, T_NEAR); + + for (int m = 0; m < jcp.m_block; m++) { + int nb2 = 0; + for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = sizeof(float) * + (m * jcp.N + nb2 * jcp.oc_block); + vmovups(EVEX_compress_addr(reg_aux_dst2,offset), + vreg_out(nb2, m)); + } + } + add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K); + add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N); + } + dec(reg_mb); + cmp(reg_mb, 0); + jg(mb_loop_label, T_NEAR); + + if (!jcp.small_mb) { + add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block); + add(reg_aux_wei, + sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block + * jcp.oc_block); + + dec(reg_nnb); + cmp(reg_nnb, 0); + jg(nnb_loop_label, T_NEAR); + } + postamble(); +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { + return jcp.mb >= 4; +} +} + +status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &wei_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, memory_desc_t &expect_wei_md) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = wei_d.dims()[with_groups + 2]; + jcp.kw = wei_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.m = 2; + jcp.r = 3; + jcp.alpha = jcp.m + jcp.r - 1; + int simdw = 16; + + format_tag_t dat_tag = format_tag::nChw16c; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simdw); + jcp.ic = rnd_up(jcp.ic, simdw); + } + + jcp.ver = ver_avx512_core; + if (!(mayiuse(avx512_core))) + return status::unimplemented; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (src_d.data_type() != data_type::f32) + return status::unimplemented; + if (wei_d.data_type() != data_type::f32) + return status::unimplemented; + if (dst_d.data_type() != data_type::f32) + return status::unimplemented; + + jcp.ic_block = simdw; + jcp.oc_block = simdw; + + bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1 + && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 + && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad + && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0 + && jcp.l_pad < 2 && jcp.l_pad >= 0; + if (!ok) + return status::unimplemented; + + const int L2_cap = get_cache_size(2, true) / sizeof(float); + const int L3_capacity = get_cache_size(3, false) / sizeof(float); + int a = jcp.alpha; + int aa = a * a; + int mb = jcp.mb; + int ic = jcp.ic; + int oc = jcp.oc; + int ih = jcp.ih; + int iw = jcp.iw; + auto wei_sz = (float)aa * ic * oc; + auto inp_sz = (float)mb * ih * iw * ic; + auto sp_sz = (float)mb * ih * iw; + + /* Heuristics here. Numbers '28','196' is an observation from data. */ + if (wei_sz / inp_sz > 5) + jcp.small_mb = true; + else + jcp.small_mb = false; + + if (mb > nstl::min(jcp.nthr, 28) + || (!jcp.small_mb + && (wei_sz >= 0.9f * L2_cap + || inp_sz > L2_cap * jcp.nthr + L3_capacity)) + || (jcp.small_mb && sp_sz > 196)) + return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_bia + = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; + + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + const int skx_free_regs = 30; + + auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block, + int &n2_block, float ®_eff) { + M = (xb * yb) / jcp.alpha; + int max_m_block = m_block = nstl::min(M, skx_free_regs); + int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs); + reg_eff = 0; + for (int im = max_m_block; im > 0; im--) { + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = in2 * im + in2; + float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f; + if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs + || cur_reg_eff <= reg_eff) + continue; + reg_eff = cur_reg_eff; + m_block = im; + n2_block = in2; + } + } + }; + + int oh = jcp.oh; + int ow = jcp.ow; + int nb_oc = jcp.nb_oc; + int Z = ic + oc; + int Y = ic * oc; + const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float); + + /* Selecting xb and yb blocking */ + int min_yb = jcp.alpha; + int min_xb = jcp.alpha; + int max_yb = nstl::max(min_yb, rnd_up(ih, 2)); + int max_xb = nstl::max(min_xb, rnd_up(iw, 2)); + float best_eff = 0.f; + for (int ix = max_xb; ix >= min_xb; ix -= 2) { + if (rnd_up(ow, ix) < iw - 2) + continue; + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + if (rnd_up(oh, iy) < ih - 2) + continue; + int ex_y = rnd_up(oh, iy); + int ex_x = rnd_up(ow, ix); + float work_eff = (float)(ih * iw) / (ex_y * ex_x); + + int M, m_block, n2_b; + float reg_eff, thr_eff, par_eff, mem_eff, req_mem; + + find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff); + + /* outer parallelization */ + int nblocks = mb * div_up(ih, iy) * div_up(iw, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); + + mem_eff = 1.f; + req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y; + if (req_mem > L2_cap / 2) { + if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7) + mem_eff /= (n2_b + 1) / 2.f; + else + mem_eff /= (n2_b + 1) / 3.f; + } + + float outer_eff = thr_eff + work_eff + reg_eff + mem_eff; + + /* inner parallelization */ + int bsz = iy * ix / a; + int gemmw = aa * (nb_oc / n2_b); + int bsz_r = rnd_up(bsz, jcp.nthr); + int gemmw_r = rnd_up(gemmw, jcp.nthr); + thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y); + + req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr)); + int oc_per_thr = + nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr)); + req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z; + if (req_mem > L2_cap) + mem_eff = 0.1f; + par_eff = 1 / (2.f * nblocks); + + float inner_eff = thr_eff + work_eff + mem_eff + par_eff; + + float eff = jcp.small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.M = M; + jcp.m_block = m_block; + jcp.n2_block = n2_b; + } + } + } + + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); + + jcp.inp_stride = jcp.M * jcp.ic; + jcp.out_stride = jcp.M * jcp.oc; + jcp.wei_stride = jcp.ic * jcp.oc; + jcp.bia_stride = jcp.oc; + + jcp.N = jcp.oc; + jcp.K = jcp.ic; + + jcp.n_block = jcp.oc_block; + jcp.k_block = jcp.ic_block; + + assert(jcp.M % jcp.m_block == 0); + assert(jcp.nb_oc % jcp.n2_block == 0); + + jcp.n_chunks = jcp.nb_oc / jcp.n2_block; + jcp.k2_block = jcp.ic_block; + jcp.k_chunks = jcp.K / jcp.k2_block; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::f32; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format + = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo; + wd.r = jcp.r; + wd.alpha = jcp.alpha; + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.ic_block; + wd.oc_block = jcp.oc_block; + wd.oc2_block = jcp.n2_block; + wd.ic2_block = 1; + wd.adj_scale = 1.f; + size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; + wd.size = max_size; + + return status::success; +} +//////////////////////////////////////////////////////////////////////////////// + +status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t + ::pd_t::jit_conf(memory_desc_t& expect_wei_md) { + return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf( + jcp_, *this->desc(), this->src_md_, this->weights_md_, + this->dst_md_,this->bias_md_, *this->attr(), expect_wei_md); +} + +jit_avx512_core_fp32_wino_conv_2x3_fwd_t:: + jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) +{ + kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( + pd()->jcp_, *pd()->attr()); + src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( + pd()->jcp_, *pd()->attr()); + dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( + pd()->jcp_, *pd()->attr()); +} + +jit_avx512_core_fp32_wino_conv_2x3_fwd_t + ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() { + delete kernel_; + delete src_trans_; + delete dst_trans_; +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN( + const float *src, const float *wei, const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const +{ + const auto &jcp = kernel_->jcp; + const auto &oscales = pd()->attr()->output_scales_; + + const size_t wino_size_offset = + (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb); + const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16; + const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bia, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bia = padded_bias; + } + + auto ptr_V = scratchpad.get(key_wino_V); + auto ptr_M = scratchpad.get(key_wino_M); + + parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb), + [&](int mb, int tile_y_b, int tile_x_b) { + int tile_y = tile_y_b * jcp.yb; + int tile_x = tile_x_b * jcp.xb; + + int ithr = mkldnn_get_thread_num(); + auto wino_src = ptr_V + size_wino_src * ithr; + auto wino_dst = ptr_M + size_wino_dst * ithr; + + auto src_trans_p = + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t + ::call_params_t(); + auto dst_trans_p = + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t + ::call_params_t(); + auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: + call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; + x_in_block += 2) { + + unsigned short v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min(jcp.alpha, + nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min(jcp.alpha, + nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.nb_ic * jcp.ih * jcp.iw + * jcp.ic_block + + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + } + } + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wei + jcp.wei_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; + x_in_block += 2) { + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; + } + auto local_d = dst + + mb * jcp.nb_oc * jcp.oh * jcp.ow + * jcp.oc_block + + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + } + } + }); +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb( + const float *src, const float *wei, const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const +{ + const auto &jcp = kernel_->jcp; + const auto &oscales = pd()->attr()->output_scales_; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get(key_conv_padded_bias); + utils::array_copy(padded_bias, bia, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bia = padded_bias; + } + + auto ptr_V = scratchpad.get(key_wino_V); + auto ptr_M = scratchpad.get(key_wino_M); + + for (int mb = 0; mb < jcp.mb; mb++) { + for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { + for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t :: + call_params_t(); + + unsigned short v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block + + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; + auto local_w = ptr_V + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: + call_params_t(); + + gemm_p.src = ptr_V + jcp.inp_stride * tile_ij; + gemm_p.dst = ptr_M + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t :: + call_params_t(); + + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; + } + auto local_d = dst + + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block + + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; + auto local_w = ptr_M + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} +} + +} // namespace cpu +} // namespace impl +} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp new file mode 100644 index 0000000000..7e38b07f5a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp @@ -0,0 +1,144 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP +#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t; +struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t; +struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t; + +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_fp32_wino_2x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_2x3_fwd_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::forward_inference + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + memory_desc_t expect_wei_md = *weights_md(); + status_t jit_conf_result = jit_conf(expect_wei_md); + if (jit_conf_result != status::success) return jit_conf_result; + set_default_alg_kind(alg_kind::convolution_winograd); + + if (weights_md_.format_kind == format_kind::any) + weights_md_ = expect_wei_md; + if (weights_md_ != expect_wei_md) + return status::unimplemented; + + init_scratchpad(); + + return status::success; + } + + jit_conv_conf_2x3_wino_t jcp_; + + protected: + status_t jit_conf(memory_desc_t& expect_wei_md); + + void init_scratchpad() { + using namespace memory_tracking::names; + + auto scratchpad = scratchpad_registry().registrar(); + + int wino_size_offset = (jcp_.yb / 2) * (jcp_.xb / 2) + jcp_.xb; + + size_t V_sz = (size_t)jcp_.ic * 16 * wino_size_offset * jcp_.nthr; + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_4K); + + size_t M_sz = (size_t)jcp_.oc * 16 * wino_size_offset * jcp_.nthr; + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_4K); + + if (wants_padded_bias()) { + assert(jcp_.ngroups == 1); + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc); + } + } + + bool set_default_formats() { + using namespace format_tag; + return set_default_formats_common(nChw16c, any, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd); + ~jit_avx512_core_fp32_wino_conv_2x3_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto wei = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bia = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + + if (pd()->jcp_.small_mb) + execute_forward_small_mb(src, wei, bia, dst, this->scratchpad(ctx)); + else + execute_forward_mbN(src, wei, bia, dst, this->scratchpad(ctx)); + + return status::success; + } + +private: + void execute_forward_small_mb(const float *src, const float *wei, + const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_mbN(const float *src, const float *wei, + const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t *kernel_; + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t *src_trans_; + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t *dst_trans_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp new file mode 100644 index 0000000000..96325e3ade --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp @@ -0,0 +1,1020 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifdef __INTEL_COMPILER +#include +#endif + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_fp32_wino_conv_4x3.hpp" + +#ifndef _MSC_VER +#define pragma_unroll _Pragma("unroll") +#else +#define pragma_unroll +#endif + + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t +::weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) const +{ + float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f, + 1.13777777777778f, 0.430252100840336f, 0.179271708683473f}; + const int kh = 3; + const int kw = 3; + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + float T[alpha][3][simd_w]; + auto p = jit_wino_transform_call_s(); + + p.src = wp; + p.dst = twp; + p.G = G; + p.M = F; + p.Mw = Fw; + p.T = T; + + kernel_->weights_transform_data_ker(&p); +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::output_transform_data +(int image, const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const { + + float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + float T[tile_size][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + p.src = toutp; + p.dst = pout_b; + p.G = G; + p.M = O; + p.Mw = Ow; + p.T = T; + p.bias = bias; + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + + kernel_->output_transform_data_ker(&p); + + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t +::output_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, + float *toutp, float *outp, float *bias) const { + + float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + float T[tile_size][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + p.src = toutp; + p.dst = outp; + p.G = G; + p.M = O; + p.Mw = Ow; + p.T = T; + p.bias = bias; + + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + + int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; + + for (int nb_tile_block_ur = 0; + nb_tile_block_ur < jcp.nb_tile_block_ur; + nb_tile_block_ur++) { + + for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; + tile_block_ur++) { + int img = tile_index / (jcp.jtiles * jcp.itiles); + int ti = tile_index % jcp.itiles; + int tj = (tile_index / jcp.itiles) % jcp.jtiles; + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block) + * outh * outw * jcp.dimM_simd_block; + + kernel_->output_transform_data_ker(&p); + + tile_index++; + } + } +} + + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t + ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const +{ + float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + + p.src = inp; + p.dst = tinp; + p.G = G; + p.M = I; + p.Mw = Iw; + p.T = T; + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + + kernel_->input_transform_data_ker(&p); + + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t + ::input_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const +{ + float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + + array_offset_calculator input(inp, + jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w); + array_offset_calculator output(tinp, + alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + auto p = jit_wino_transform_call_s(); + + p.dst = tinp; + p.G = G; + p.M = I; + p.Mw = Iw; + p.T = T; + + + int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; + + for (int nb_tile_block_ur = 0; + nb_tile_block_ur < jcp.nb_tile_block_ur; + nb_tile_block_ur++) { + + for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; + tile_block_ur++) { + + int img = tile_index / (jcp.jtiles * jcp.itiles); + int ti = tile_index % jcp.itiles; + int tj = (tile_index / jcp.itiles) % jcp.jtiles; + float *pinp_b = &(input(img, 0, 0, 0, 0)); + + p.src = pinp_b; + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tj = tj; + p.ti = ti; + + kernel_->input_transform_data_ker(&p); + + tile_index++; + } + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::_execute_data_W_S_G_D( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + /* Notation: + FWD: dimM:oc, dimN:ntiles, dimK:ic, + BWD: dimM:ic, dimN:ntiles, dimK:oc, + FWD/BWD: V: src/diff_dst transform, U:weight transform, + M:dst/diff_src transform */ + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, + jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, + jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + + auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) + ? wei_ptr + : scratchpad.template get(key_wino_U); + + array_offset_calculator U(wino_wei, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + { + + parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + [&](int img, int K_blk1, int K_blk2) { + input_transform_data(img, jcp, + &(input(img, K_blk1 * jcp.dimK_block + K_blk2, + 0, 0, 0)), + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + }); + + if (jcp.prop_kind != prop_kind::forward_inference) { + parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), + (jcp.ic_block * jcp.ic_reg_block), + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights( + ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, + ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, + 0, 0, 0, 0)), + U_base_ptr); + }); + } + + parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, + [&](int N_blk1, int oj, int oi, int M_blk1) { + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; + K_blk1++) + for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++) + kernel_->gemm_loop_ker( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + K_blk1, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, K_blk1, 0, 0, 0)), K_blk1); + }); + + parallel_nd(jcp.mb, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block), + [&](int img, int M_blk1, int M_blk2) { + const int M_blk = + M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + output_transform_data(img, jcp, p_ops, + &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(img, M_blk, 0, 0, 0)), bias_ptr); + }); + + } +} + +template +void _jit_avx512_core_fp32_wino_conv_4x3_t::_execute_data_W_SGD( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + array_offset_calculator input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block); + array_offset_calculator output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block); + array_offset_calculator weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator bias(bias_ptr, + jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block); + + auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) + ? wei_ptr + : scratchpad.template get(key_wino_U); + + array_offset_calculator U(wino_wei, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + + array_offset_calculator M(is_fwd + ? scratchpad.template get(key_wino_M) + : scratchpad.template get(key_wino_V), + 0, jcp.dimM_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + array_offset_calculator V(is_fwd + ? scratchpad.template get(key_wino_V) + : scratchpad.template get(key_wino_M), + 0, alpha, alpha, jcp.dimN_block, + jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + if (jcp.prop_kind != prop_kind::forward_inference) { + + parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block), + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights( + ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, + ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, + 0, 0, 0, 0)), + U_base_ptr); + }); + } + + parallel_nd(jcp.tile_block, [&](int tile_block) { + int ithr = mkldnn_get_thread_num(); + + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) { + for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) { + + input_transform_tileblock_data( + tile_block, jcp, + &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), + &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + } + } + + for (int oj = 0; oj < alpha; oj++) { + for (int oi = 0; oi < alpha; oi++) { + for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) + for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++) + kernel_->gemm_loop_ker( + (float *)&(M(ithr, M_blk1, oj, oi, + N_blk, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, K_blk1, + 0, 0, 0, 0)), + (const float *)&(V(ithr, oj, oi, + N_blk, K_blk1, 0, 0, 0)), K_blk1); + } + } + + for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) { + for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block; + M_blk2++) { + const int M_blk = + M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + + output_transform_tileblock_data(tile_block, jcp, p_ops, + &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(0, M_blk, 0, 0, 0)), bias_ptr); + } + } + }); +} + +template struct _jit_avx512_core_fp32_wino_conv_4x3_t; +template struct _jit_avx512_core_fp32_wino_conv_4x3_t; + +namespace { + +void subarray_sum(size_t num_arrs, float *output, size_t nelems, + float *input_ptrs[], size_t input_starts[], size_t input_ends[]) { + using namespace nstl; + const size_t block_size = 16 * 1024 / sizeof(float); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + +PRAGMA_OMP(parallel) + { + const int ithr = mkldnn_get_thread_num(); + const int nthr = mkldnn_get_num_threads(); + size_t start{ 0 }, end{ 0 }; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + size_t input_start = max(start_e, min(input_starts[0], end_e)); + size_t input_end = max(start_e, min(input_ends[0], end_e)); + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < input_start; e++) { + output[e] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] = input_ptrs[0][e]; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_end; e < end_e; e++) { + output[e] = 0.f; + } + + for (size_t a = 1; a < num_arrs; a++) { + input_start = max(start_e, input_starts[a]); + input_end = min(input_ends[a], end_e); + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + size_t input_start = max(start_e, min(input_starts[0], end_e)); + size_t input_end = max(start_e, min(input_ends[0], end_e)); + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < input_start; e++) { + output[e] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] = input_ptrs[0][e]; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_end; e < end_e; e++) { + output[e] = 0.f; + } + + for (size_t a = 1; a < num_arrs; a++) { + input_start = max(start_e, input_starts[a]); + input_end = min(input_ends[a], end_e); + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + } +} + +const int max_threads_number = 1024; + +// Sum to the first buffer array +void array_sum(size_t num_arrs, float *output, + size_t nelems, float *input_ptrs[], bool reduce_to_first = true) { + const size_t block_size = 16 * 1024 / sizeof(float); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + +PRAGMA_OMP(parallel) + { + const size_t ithr = mkldnn_get_thread_num(); + const size_t nthr = mkldnn_get_num_threads(); + size_t start{ 0 }, end{ 0 }; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + if (!reduce_to_first) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = input_ptrs[0][e]; + } + } + for (size_t a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + if (!reduce_to_first) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = input_ptrs[0][e]; + } + } + for (size_t a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + } +} +} //bwdw namespace + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: +_execute_backward_weights_SDGtWo(const float *ptr_src, + const float *ptr_diff_dst, float *ptr_diff_weights, + float *ptr_diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights(ptr_diff_weights, + jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + + array_offset_calculator Us(scratchpad.get(key_wino_U), + 0, alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc + * jcp.ic / jcp.nb_ic; + array_offset_calculatordiff_weights_prv( + scratchpad.get(key_wino_U) + U_sz, + 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + + array_offset_calculator M(scratchpad.get(key_wino_M), + 0, alpha, alpha, + jcp.oc_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur, + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator V(scratchpad.get(key_wino_V), + 0, alpha, alpha, + jcp.ic_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur, + jcp.ic_simd_block); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), nthreads, jcp.oc); + + auto trans_ker_p = jit_wino_transform_call_s(); + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f, + 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f, + 1.13777777777778f}; + float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; + +PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T)) +{ + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc / simd_w, + [&](int ithr, int ofm){ + float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + pdbias[v] = 0.0f; + } + }); + } + + int ithr = mkldnn_get_thread_num(); + for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) { + int first_tblk = 0; +PRAGMA_OMP(for) + for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) { + int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur; + int img = tile_index / (jcp.itiles * jcp.jtiles); + trans_ker_p.ti = tile_index % jcp.itiles; + trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles; + trans_ker_p.M = I; + trans_ker_p.T = T; + trans_ker_p.G = G_I_3x3_4x4; + for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0)); + kernel_->src_transform(&trans_ker_p); + } + + for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) { + trans_ker_p.G = G_W_3x3_4x4; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; + trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0)); + if (jcp.with_bias && ifm1 == 0) { + trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); + kernel_->diff_dst_transform_wbias(&trans_ker_p); + } else { + kernel_->diff_dst_transform(&trans_ker_p); + } + } + + for (int oj = 0; oj < alpha; ++oj) { + for (int oi = 0; oi < alpha; ++oi) { + kernel_->gemm_loop_ker_first_iter( + &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)), + &(M(ithr, oj, oi, 0, 0, 0, 0, 0)), + &(V(ithr, oj, oi, 0, 0, 0, 0))); + } + } + trans_ker_p.G = G_O_3x3_4x4; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { + for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) { + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block + + ofm3; + for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(Us(ithr, 0, 0, + ofm2, ifm2, 0, ofm3, 0)); + trans_ker_p.dst = (float *)&(diff_weights_prv(ithr, + ofm, ifm, 0, 0, 0, 0)); + if (first_tblk == 0) { + kernel_->diff_weights_transform(&trans_ker_p); + } else { + kernel_->diff_weights_transform_accum(&trans_ker_p); + } + } + } + } + } + ++first_tblk; + } + } +} + + // Reduce diff-weights + { + float *output = ptr_diff_weights; + float *input_base = scratchpad.get(key_wino_U) + U_sz; + int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + float *input_ptrs[max_threads_number]; + for (int i = 0; i < nthreads; ++i) { + input_ptrs[i] = input_base + nelems * i; + } + array_sum(nthreads, output, nelems, input_ptrs, false); + + if (jcp.with_bias) { + output = ptr_diff_bias; + input_base = scratchpad.get(key_conv_bia_reduction); + for (int i = 0; i < nthreads; ++i) { + input_ptrs[i] = input_base + jcp.oc * i; + } + array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs, + false); + } + } +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: +_execute_backward_weights_S_D_Giot_W(const float *ptr_src, + const float *ptr_diff_dst, float *ptr_diff_weights, + float *ptr_diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + array_offset_calculator src((float *)ptr_src, + jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator diff_weights((float *)ptr_diff_weights, + jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + array_offset_calculator diff_bias((float *)ptr_diff_bias, jcp.oc); + + array_offset_calculator U(scratchpad.get(key_wino_U), + jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + const int U_size = jcp.oc * jcp.ic * alpha * alpha; + array_offset_calculator Us( + scratchpad.get(key_wino_U) + U_size, + 0, jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator M(scratchpad.get(key_wino_M), + jcp.nb_oc, + jcp.tile_block, + alpha, alpha, + jcp.oc_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur , + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator V(scratchpad.get(key_wino_V), + jcp.nb_ic, + jcp.tile_block, + alpha, alpha, + jcp.ic_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur, + jcp.ic_simd_block); + + array_offset_calculator diff_bias_prv( + scratchpad.get(key_conv_bia_reduction), nthreads, jcp.oc); + + size_t input_starts[max_threads_number] = {0}; + size_t input_ends[max_threads_number] = {0}; + size_t first_tblk = 0; + + auto trans_ker_p = jit_wino_transform_call_s(); + float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, + 0.119514472455649f, 0.430252100840336f, 0.168067226890756f, + 0.179271708683473f, 0.403361344537815f, 1.13777777777778f}; + float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + +PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T)) +{ + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { + diff_bias_prv(ithr, ofm) = 0.0f; + }); + } + + trans_ker_p.G = G_I_3x3_4x4; + trans_ker_p.M = I; + trans_ker_p.T = T; + + parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb, + [&](int ifm1, int ifm2, int img){ + size_t ifm = ifm1 * jcp.ic_block + ifm2; + size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); + size_t tblk3 = tile_base_index % jcp.tile_block_ur; + size_t tblk2 = (tile_base_index / jcp.tile_block_ur) + % jcp.nb_tile_block_ur; + size_t tblk1 = (tile_base_index / jcp.tile_block_ur) + / jcp.nb_tile_block_ur; + trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; + trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0)); + kernel_->src_transform(&trans_ker_p); + }); + + int ithr = mkldnn_get_thread_num(); + trans_ker_p.G = G_W_3x3_4x4; + parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb, + [&](int ofm1, int ofm2, int img){ + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; + size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); + size_t tblk3 = tile_base_index % jcp.tile_block_ur; + size_t tblk2 = (tile_base_index / jcp.tile_block_ur) + % jcp.nb_tile_block_ur; + size_t tblk1 = (tile_base_index / jcp.tile_block_ur) + / jcp.nb_tile_block_ur; + trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; + trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0)); + if (jcp.with_bias) { + trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); + kernel_->diff_dst_transform_wbias(&trans_ker_p); + } else { + kernel_->diff_dst_transform(&trans_ker_p); + } + }); + + PRAGMA_OMP(barrier) + + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block, + [&](int ifm1, int ofm1, int oj, int oi, int tblk1){ + if (first_tblk == 0) { + input_starts[ithr] = + (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, + 0, 0)) + - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0, + 0, 0, 0)); + input_ends[ithr] = input_starts[ithr] + + jcp.oc_block * jcp.ic_block + * jcp.ic_simd_block * jcp.oc_reg_block + * jcp.oc_simd_block; + } + else if (tblk1 == 0) { + input_ends[ithr] += jcp.oc_block * jcp.ic_block + * jcp.ic_simd_block * jcp.oc_reg_block + * jcp.oc_simd_block; + } + + if (first_tblk == 0 || tblk1 == 0) { + kernel_->gemm_loop_ker_first_iter( + &(Us(ithr, ifm1, ofm1, oj, oi, + 0, 0, 0, 0, 0)), + &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), + &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); + } else { + kernel_->gemm_loop_ker( + &(Us(ithr, ifm1, ofm1, oj, oi, + 0, 0, 0, 0, 0)), + &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), + &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); + } + ++first_tblk; + }); +} + + // Reduce diff-weights + { + float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0)); + size_t nelems = jcp.ic * jcp.oc * alpha * alpha; + float *input_ptrs[max_threads_number]; + for (int i = 0; i < nthreads; ++i) + input_ptrs[i] = output + nelems * (i + 1); + subarray_sum(nthreads, output, nelems, input_ptrs, + input_starts, input_ends); + } + + trans_ker_p.G = G_O_3x3_4x4; +PRAGMA_OMP(parallel firstprivate(trans_ker_p)) + { + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block, + [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){ + int ofm = (ofm1 * jcp.oc_block + ofm2) + * jcp.oc_reg_block + ofm3; + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0, + ofm2, ifm2, 0, ofm3, 0)); + trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm, + 0, 0, 0, 0)); + kernel_->diff_weights_transform(&trans_ker_p); + }); + } + + if (jcp.with_bias) { + parallel_nd(jcp.oc / simd_w, [&](int ofm1) { + float* pbias = &(diff_bias(ofm1 * simd_w)); + float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w)); + + const int blk_sz = ofm1 == jcp.oc / simd_w - 1 + ? jcp.oc_without_padding - ofm1 * simd_w : simd_w; + + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { + pbias[ofm2] = pbias_prv[ofm2]; + } + + for (int ithr = 1; ithr < nthreads; ++ithr) { + pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w)); + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { + pbias[ofm2] += pbias_prv[ofm2]; + } + } + }); + } +} + +} +} +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp new file mode 100644 index 0000000000..f1a56aac70 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp @@ -0,0 +1,386 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP +#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace winograd_avx512_core { +inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_winograd_conf_t &jcp) { + using namespace utils; + using namespace memory_tracking::names; + + size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; + size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles + * jcp.jtiles; + size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles + * jcp.jtiles; + + switch (jcp.sched_policy) { + case WSCHED_DATA_W_SGD: + V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur + * jcp.tile_block_ur * jcp.ic; + M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur + * jcp.tile_block_ur * jcp.oc; + break; + case WSCHED_WEI_SDGtWo: + U_sz = (size_t)jcp.nthr * (alpha * alpha * jcp.oc + * (jcp.ic / jcp.nb_ic) + jcp.ic * jcp.oc * jcp.kh * jcp.kw); + M_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) + * (jcp.oc / jcp.nb_oc); + V_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) + * (jcp.ic / jcp.nb_ic); + break; + case WSCHED_WEI_S_D_Giot_W: + U_sz = (size_t)(jcp.nthr + 1) * alpha * alpha * jcp.ic * jcp.oc; + M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles; + V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles; + break; + default: break; + } + + scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); + + if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) { + size_t br_sz = (size_t)jcp.nthr * jcp.oc; + scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); + } +} +} + +template +struct _jit_avx512_core_fp32_wino_conv_4x3_t { + + _jit_avx512_core_fp32_wino_conv_4x3_t( + const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) + : kernel_(nullptr), attr_(attr) { + kernel_ = new _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(jcp); + } + + ~_jit_avx512_core_fp32_wino_conv_4x3_t() { delete kernel_; } + + protected: + void weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) const; + void input_transform_data(int image, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const; + void input_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const; + void output_transform_data(int image, + const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, + float *bias) const; + void output_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, + float *toutp, float *outp, float *bias) const; + void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + void _execute_data_W_SGD(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel *kernel_; + const primitive_attr_t *attr_; +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_fwd_t + : _jit_avx512_core_fp32_wino_conv_4x3_t + , public cpu_primitive_t + { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel:: + init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, + *attr()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = desc()->prop_kind == prop_kind::forward_training + ? (with_groups() ? gOIhw16i16o : OIhw16i16o) : any; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *apd) + : _jit_avx512_core_fp32_wino_conv_4x3_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) + {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + switch ((pd()->jcp_).sched_policy) { + case WSCHED_DATA_W_S_G_D: + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, + (float *)bias, scratchpad); + break; + case WSCHED_DATA_W_SGD: + this->_execute_data_W_SGD((float *)src, dst, (float *)weights, + (float *)bias, scratchpad); + break; + default: + break; + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t + : _jit_avx512_core_fp32_wino_conv_4x3_t, + public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t); + + status_t init() { + bool ok = true + && mkldnn_thr_syncable() + && desc()->prop_kind == prop_kind::backward_data + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel + ::init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *apd) + : _jit_avx512_core_fp32_wino_conv_4x3_t(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) + {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); + + auto scratchpad = this->scratchpad(ctx); + + switch ((pd()->jcp_).sched_policy) { + case WSCHED_DATA_W_S_G_D: + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, + (float *)weights, NULL, scratchpad); + break; + + case WSCHED_DATA_W_SGD: + this->_execute_data_W_SGD((float *)diff_dst, diff_src, + (float *)weights, NULL, scratchpad); + break; + + default: + break; + } + + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t + : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t); + + status_t init() { + bool ok = true + && mkldnn_thr_syncable() + && desc()->prop_kind == prop_kind::backward_weights + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) + return status::unimplemented; + + status_t status = + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: + init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), + *diff_weights_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + , kernel_(nullptr) + { + kernel_ = new jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( + pd()->jcp_); + } + + ~jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t() + { + delete kernel_; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + switch (kernel_->jcp.sched_policy) { + case WSCHED_WEI_SDGtWo: + _execute_backward_weights_SDGtWo(src, diff_dst, diff_weights, + diff_bias, scratchpad(ctx)); + break; + case WSCHED_WEI_S_D_Giot_W: + _execute_backward_weights_S_D_Giot_W(src, diff_dst, diff_weights, + diff_bias, scratchpad(ctx)); + break; + default: + assert(kernel_->jcp.sched_policy != WSCHED_INVALID); + break; + } + return status::success; + } + +private: + void _execute_backward_weights_SDGtWo(const float *src, + const float *diff_dst, float *diff_weights, float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + void _execute_backward_weights_S_D_Giot_W(const float *src, + const float *diff_dst, float *diff_weights, float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp new file mode 100644 index 0000000000..0d64a2d13a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp @@ -0,0 +1,2596 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include + +#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_wino_transform_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace mkldnn::impl::utils; + +unsigned int L1_cache_size = get_cache_size(1, true); +unsigned int L2_cache_size = get_cache_size(2, true); +unsigned int LLC_data_size = get_cache_size(3, false); + +// the test funtion takes jcp, the candidate and the current best. +// it returns true if the new candidate is better +int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, + int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) +{ + int best_divisor = default_best; + auto test_num + = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { + if (test(jcp, num, best_divisor)) { + best_divisor = num; + } + }; + + for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { + if (number % divisor == 0) { + test_num(jcp, divisor); + test_num(jcp, number / divisor); + } + } + + return best_divisor; +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { + /* Determines if current winograd implementation is faster than direct. + Following conditions are empirical and based on performance data */ + unsigned int ncores_per_socket = + cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); + unsigned int nthreads = mkldnn_get_max_threads(); + + if (jcp.prop_kind == prop_kind::forward_inference) { + return jcp.mb >= 4; + } else if (nthreads > ncores_per_socket) { + double src_dst_transforms_per_core = alpha * alpha + * (jcp.ic + jcp.oc) + * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size) + * ((jcp.ow + tile_size - 1) / tile_size) + * sizeof(float) / 1024. / 1024. / nthreads; + double wei_transform = alpha * alpha + * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.; + + if (jcp.prop_kind == prop_kind::backward_weights) { + if (src_dst_transforms_per_core < 0.3 + || (src_dst_transforms_per_core <= 28 && wei_transform < 4)) + return false; + else + return true; + } else { + if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02) + return false; + } + } + + return jcp.mb > 8; +} +} + +/* assumes 512 bits registers */ +/* TODO: add support for strides */ +/* TODO: handle the prefetch distance automatically */ +typedef enum cache_t_ { L1, L2, L3 } cache_t; + +template +struct prefetcher_t { + prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, + cache_t cache_type, size_t block_size, /* in number of elements*/ + int nb_instructions_in_block, int fma_ipc) + : cg_(generator) + , reg_base_addr_(reg_base_addr) + , cache_type_(cache_type) + , cache_block_size_(block_size) + { + nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); + prefetch_spread_ + = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); + prefetch_blk_ + = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); + + /* assumption: when fetch in Li, data is already in L(i+1) */ + int cache_latency; + switch (cache_type_) { + case L1: cache_latency = 14; break; + case L2: cache_latency = 250; break; + case L3: cache_latency = 250; break; + } + + prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); + } + + void prefetch(int instruction_number) + { + if (instruction_number % prefetch_spread_ == 0) { + for (int i = 0; (i < prefetch_blk_) + && (prefetches_issued_ < nb_cache_lines_to_prefetch_); + i++, prefetches_issued_++) { + prefetch_inst_(cg_->EVEX_compress_addr( + reg_base_addr_, (cache_block_size_ * prefetch_distance_) + * sizeof(data_t) + + (prefetches_issued_ * 64))); + } + } + } + +private: + void prefetch_inst_(const Xbyak::Address &addr) + { + switch (cache_type_) { + case L1: cg_->prefetcht0(addr); break; + case L2: cg_->prefetcht1(addr); break; + case L3: cg_->prefetcht2(addr); break; + default: + break; // TODO: raise an exception or put an assert + } + } + + jit_generator *cg_; + Xbyak::Reg64 reg_base_addr_; + cache_t cache_type_; + int cache_block_size_ = 0; + int nb_cache_lines_to_prefetch_ = 0; + int prefetches_issued_ = 0; + int prefetch_spread_ = 0; + int prefetch_blk_ = 0; + int prefetch_distance_ = 0; +}; + +// utilities to support kernel parameter selection +bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, + int dimN_block, float C2_min, float C2_max) { + float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic) + * dimN_block * jcp.dimN_reg_block + + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float); + float L2_lb = C2_min * L2_cache_size; + float L2_ub = C2_max * L2_cache_size; + return (block_size > L2_lb && block_size < L2_ub); +} + +bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block, + int dimM_block, float C1_min, float C1_max) { + float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block + * jcp.dimK_reg_block * jcp.dimM_reg_block + + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block + + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block) + * (float)sizeof(float); + float L1_lb = C1_min * L1_cache_size; + float L1_ub = C1_max * L1_cache_size; + return (gemm_block_size > L1_lb && gemm_block_size < L1_ub); +} +bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block + + dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block * dimM_reg_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} +bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block + * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} +bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, + int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block, + int dimM_simd_block, float C) +{ + float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block + * dimM_simd_block * dimM_reg_block + + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block * dimM_reg_block + + nb_dimN_reg_block * dimK_nb_block * dimK_block + * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L2_cache_size; + return (lhs < rhs); +} + +bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block, + int dimN_block, int dimN_reg_block, int dimK, float C1, float C2) +{ + float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK + * (float)sizeof(float); + float B_size = dimN_block * dimN_reg_block * dimK + * (float)sizeof(float); + return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size); +} +} + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate() +{ + // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) + // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block; + // dimM_reg_block++) // unrolled + // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) + // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; + // dimK_reg_block++) // unrolled + // for (int tile =0; tile < jcp.dimN_reg_block; tile++) + // C[dimM_block][dimM_reg_block][tile] += + // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block] + // * broadcast(B[dimK_block][tile][dimK_reg_block]); + // Notes: + // jcp.kernel_kind defines embedded or explicit broadcast + // dimM_reg_block=1 for embedded bcast kernel + + auto zmm_srcA = [=]() { + return Xbyak::Zmm(0); + }; + auto zmm_srcB = [=](int tile) { + int idx = 1 + tile; + assert(idx < 1 + jcp.dimN_reg_block); + return Xbyak::Zmm(idx); + }; + auto zmm_dstC = [=](int dimM_reg_block, int tile) { + int idx{0}; + if (jcp.kernel_kind == embd_bcast) + idx = 1 + tile; + else + idx = 1 + jcp.dimN_reg_block + + dimM_reg_block * jcp.dimN_reg_block + tile; + assert(idx < 32); + return Xbyak::Zmm(idx); + }; + + auto prepare_output = [=]() { + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block, tile); + vpxord(zmm, zmm, zmm); + } + } + }; + auto store_output = [=](bool output_is_aligned) { + Label save; + cmp(reg_is_beta_zero, 0); + je(save, T_NEAR); + + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block,tile); + int output_offset + = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; + vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset)); + } + } + + L(save); + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block,tile); + int output_offset + = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; + + // In W_SGD, output will be reused. + if (output_is_aligned + && jcp.dimK_nb_block == 1 + && jcp.sched_policy == WSCHED_DATA_W_S_G_D + && (jcp.dimN * jcp.dimM * alpha * alpha + * sizeof(float) > 2 * LLC_data_size)) + vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm); + else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm); + } + } + }; + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop; + + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + + prepare_output(); + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block ++) { + + if (jcp.kernel_kind == expl_bcast) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + vbroadcastss(zmm_srcB(tile), + ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]); + } + } + + /* Performing the fmas */ + + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + + vmovups(zmm_srcA(), + zword[reg_srcA + + jcp.dimK_reg_block * jcp.dimK_block * 64 + * dimM_reg_block + + dimK_reg_block * 64] + ); + + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), + zmm_srcB(tile)); + else + vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4, true)); + } + } + } + add(reg_srcA, jcp.dimK_reg_block * 64); + add(reg_srcB, jcp.dimN_reg_block * 64); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + Label unaligned_store, end_store; + test(reg_dstC, cpu_isa_traits::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); + add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64); + if (jcp.kernel_kind == expl_bcast) { + add(reg_srcA, + (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64 + * jcp.dimK_block); + } + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + }; + + /* Preamble */ + preamble(); + + /* kernel */ + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::weights_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int kh = jcp.kh; + int kw = jcp.kw; + + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_zero = Xbyak::Zmm(30); + + auto zmm_M = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_MT = [=](int i) { + return Xbyak::Zmm(i + simd_w); + }; + + auto zmm_G = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_F = [=](int i) { + return Xbyak::Zmm(alpha + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(alpha + 3 + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(2 * alpha + 3 + i); + }; + + auto zmm_load = [=](int i) { + return Xbyak::Zmm(i); + }; + + auto init_G = [=]() { + mov(wreg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < alpha; i++) { + vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]); + } + vpxord(zmm_zero, zmm_zero, zmm_zero); + }; + + auto trans16x16 = [=]() { + for (int i = 0; i < simd_w; i+=2 ) { + vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]); + vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]); + vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1)); + vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1)); + } + for (int i = 0; i < simd_w; i+=4 ) { + vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2)); + vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2)); + vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3)); + vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3)); + } + for (int i = 0; i < simd_w; i += 8) { + vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88); + vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88); + vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88); + vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88); + vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd); + vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd); + vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd); + vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd); + } + { + int i = 0; + int mask = 0x88; + vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask); + vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0)); + vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask); + vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1)); + vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask); + vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2)); + vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask); + vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3)); + vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask); + vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4)); + vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask); + vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5)); + vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask); + vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6)); + vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask); + vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7)); + mask = 0xdd; + vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask); + vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8)); + vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask); + vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9)); + vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask); + vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10)); + vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask); + vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11)); + vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask); + vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12)); + vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask); + vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13)); + vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask); + vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14)); + vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask); + vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15)); + } + }; + + auto load_src = [=]() { + mov(wreg_src, ptr[param1 + GET_OFF(src)]); + mov(wreg_F, ptr[param1 + GET_OFF(M)]); + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + if (is_fwd) { + for (int v1 = 0; v1 < simd_w; v1++) { + int offset_src = (j * kw * simd_w * simd_w + + i * simd_w * simd_w + v1 * simd_w) * typesize; + int offset_F = (j * kw * simd_w * simd_w + + i * simd_w * simd_w + v1 * simd_w) * typesize; + vmovups(zmm_temp, ptr[wreg_src + offset_src]); + vmovups(ptr[wreg_F + offset_F], zmm_temp); + } + } else { + int offset_src = ((2 - j) * kw * simd_w * simd_w + + (2 - i) * simd_w * simd_w) * typesize; + int offset_F = (j * kw * simd_w * simd_w + + i * simd_w * simd_w) * typesize; + lea(wreg_M, ptr[wreg_src + offset_src]); + lea(wreg_MT, ptr[wreg_F + offset_F]); + trans16x16(); + } + } + } + }; + + auto store_dst = [=]() { + mov(wreg_dst, ptr[param1 + GET_OFF(dst)]); + mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); + + Label Loop_j; + mov(wreg_cnt_j, 0); + mov(wreg_dst_aux, wreg_dst); + mov(wreg_Fw_aux, wreg_Fw); + + int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block) + * jcp.dimK_block * simd_w * simd_w; + + L(Loop_j); + { + for (int i = 0; i < alpha; i++) { + // touch pages + vmovups(zmm_load(0), ptr[wreg_Fw_aux + + (i * simd_w * simd_w) * typesize]); + mov(wreg_dst_idx, i * dim5 * typesize); + vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0)); + } + for (int i = 0; i < alpha; i++) { + for (int v1 = 1; v1 < simd_w; v1++) { + int offset_Fw = (i * simd_w * simd_w + v1 * simd_w) + * typesize; + vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]); + } + mov(wreg_dst_idx, i * dim5 * typesize); + for (int v1 = 1; v1 < simd_w; v1++) { + int offset_dst = v1 * simd_w * typesize; + vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst], + zmm_load(v1)); + } + } + add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize); + add(wreg_dst_aux, alpha * dim5 * typesize); + add(wreg_cnt_j, 1); + cmp(wreg_cnt_j, alpha); + jl(Loop_j, T_NEAR); + } + }; + + auto trans_W_4x4_3x3 = [=]() { + auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmovups(dst, a); + vfmadd231ps(dst, b, c); + }; + auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmulps(zmm_temp, b, c); + vsubps(dst, a, zmm_temp); + }; + auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vsubps(dst, zmm_zero, a); + vfnmadd231ps(dst, b, c); + }; + + mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); + mov(wreg_F, ptr[param1 + GET_OFF(M)]); + mov(wreg_T, ptr[param1 + GET_OFF(T)]); + + Label Loop_j; + mov(wreg_cnt_j, 0); + L(Loop_j); + mov(wreg_F_aux, wreg_F); + mov(wreg_Fw_aux, wreg_Fw); + mov(wreg_temp, wreg_cnt_j); + shl(wreg_temp, 4 + 2); + lea(wreg_F_aux, ptr[wreg_F + wreg_temp]); + lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]); + + for (int i = 0; i < 3; i++) { + for (int idx = 0; idx < 3; idx ++) { + vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w + * simd_w + i * simd_w * simd_w) * typesize]); + } + vmulps(zmm_t(0), zmm_G(0), zmm_F(2)); + fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0)); + fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0)); + + vmulps(zmm_T(0), zmm_G(3), zmm_F(0)); + fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1)); + fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1)); + fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1)); + fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1)); + vmovaps(zmm_T(5), zmm_F(2)); + + for (int idx = 0; idx < 6; idx ++) { + vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w) + * typesize], zmm_T(idx)); + } + } + for (int i = 0; i < 6; i++) { + + for (int idx = 0; idx < 3; idx ++) { + vmovups(zmm_T(idx), ptr[wreg_T + + (i * 3 * simd_w + idx * simd_w) * typesize]); + } + vmulps(zmm_t(0), zmm_G(0), zmm_T(2)); + fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0)); + fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0)); + + vmulps(zmm_F(0), zmm_G(3), zmm_T(0)); + fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1)); + fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1)); + fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1)); + fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1)); + vmovaps(zmm_F(5), zmm_T(2)); + + for (int l = 0; l < 6; l++) { + vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w + + l * simd_w * simd_w) * typesize], zmm_F(l)); + } + } + add(wreg_cnt_j, 1); + cmp(wreg_cnt_j, 16); + jl(Loop_j, T_NEAR); + }; + + auto inner_loops = [=]() { + load_src(); + init_G(); + trans_W_4x4_3x3(); + store_dst(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::output_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; + bool with_bias = jcp.with_bias; + bool with_relu = jcp.with_eltwise; + bool with_relu_postsum = jcp.with_relu_postsum; + bool with_sum = jcp.with_sum; + + auto zmm_zero = Xbyak::Zmm(0); + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_G = [=](int i) { + return Xbyak::Zmm(1 + i); + }; + auto zmm_O = [=](int i) { + return Xbyak::Zmm(1 + alpha + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(1 + 2 * alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(1 + 3 * alpha + i); + }; + + auto init_G = [=]() { + mov(oreg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < 6; i++) { + vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); + mov(oreg_src, ptr[param1 + GET_OFF(src)]); + + mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); + imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur, + (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block + * jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_nb_tile_block_ur); + + mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); + imul(oreg_tile_block_ur, oreg_tile_block_ur, + jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_tile_block_ur); + + if (not_tiled) { + mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]); + imul(oreg_tile_block, oreg_tile_block, + jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block + * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block + * jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_tile_block); + } + + int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block) + * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize; + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + int j_base_offset = j * alpha * last4dim; + int i_base_offset = i * last4dim; + vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]); + vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w) + * typesize], zmm_temp); + } + } + }; + + auto store_dst = [=]() { + vpxord(zmm_zero, zmm_zero, zmm_zero); + mov(oreg_dst, ptr[param1 + GET_OFF(dst)]); + mov(oreg_O, ptr[param1 + GET_OFF(M)]); + mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]); + shl(oreg_ydim, 2); // tj * tile_size (==4) + mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]); + shl(oreg_xdim, 2); // ti * tilesize (==4) + + if (with_bias) + mov(oreg_bias, ptr[param1 + GET_OFF(bias)]); + + auto store_one = [=](int j, int i, bool is_aligned) { + auto zmm_O = Xbyak::Zmm(31); + auto zmm_relu_ns = Xbyak::Zmm(30); + auto xmm_relu_ns = Xbyak::Xmm(30); + int offset = (j * tile_size * simd_w + i * simd_w) * typesize; + + vmovups(zmm_O, ptr[oreg_O + offset]); + if (is_fwd) { + if (with_bias) { + vaddps(zmm_O, zmm_O, ptr[oreg_bias]); + } + if (with_relu) { + if (jcp.eltwise.alpha == 0) { + vmaxps(zmm_O, zmm_O, zmm_zero); + } else { + Opmask kmask = Opmask(7); + mov(imm_addr64, float2int(jcp.eltwise.alpha)); + vmovq(xmm_relu_ns, imm_addr64); + vbroadcastss(zmm_relu_ns, xmm_relu_ns); + vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os); + vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns); + } + } + } + if (with_sum) { + vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]); + if (with_relu_postsum) // orig: with_relu_postsum + vmaxps(zmm_O, zmm_O, zmm_zero); + } + if (is_aligned) + vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O); + else + vmovups(ptr[oreg_out_j + oreg_temp], zmm_O); + }; + + auto i_loop = [=](int j, bool is_aligned) { + for (int i = 0; i < tile_size; i++) { + Label next; + mov(oreg_temp, oreg_xdim); + add(oreg_temp, i); + cmp(oreg_temp, outw); + jge(next, T_NEAR); + shl(oreg_temp, 4 + 2); // * 16 * 4 + + store_one(j, i, is_aligned); + + L(next); + } + }; + + + for (int j = 0; j < tile_size; j++) { + Label next, unaligned; + mov(oreg_temp, oreg_ydim); + add(oreg_temp, j); + cmp(oreg_temp, outh); + jge(next, T_NEAR); + + mov(oreg_out_j, oreg_dst); + imul(oreg_temp, oreg_temp, outw * simd_w * typesize); + add(oreg_out_j, oreg_temp); + + test(oreg_dst, 63); + jnz(unaligned, T_NEAR); + + i_loop(j, true); + jmp(next, T_NEAR); + + L(unaligned); + i_loop(j, false); + + L(next); + } + }; + + auto trans_O_4x4_3x3 = [=]() { + auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){ + vmulps(dst, v1, u1); + vfmadd231ps(dst, v2, u2); + }; + mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); + mov(oreg_T, ptr[param1 + GET_OFF(T)]); + mov(oreg_O, ptr[param1 + GET_OFF(M)]); + + for (int i = 0; i < alpha; i++) { + for (int j = 0; j < alpha; j++) { + vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w + + i * simd_w) * typesize]); + } + + vaddps(zmm_t(0), zmm_O(1), zmm_O(2)); + vaddps(zmm_t(1), zmm_O(3), zmm_O(4)); + vsubps(zmm_t(2), zmm_O(1), zmm_O(2)); + vsubps(zmm_t(3), zmm_O(3), zmm_O(4)); + + vaddps(zmm_T(0), zmm_t(0), zmm_t(1)); + vaddps(zmm_T(0), zmm_T(0), zmm_O(0)); + fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); + fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); + fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); + vaddps(zmm_T(3), zmm_T(3), zmm_O(5)); + + for (int j = 0; j < tile_size; j++) { + vmovups(ptr[oreg_T + (j * alpha * simd_w + + i * simd_w) * typesize], zmm_T(j)); + } + } + for (int j = 0; j < tile_size; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w + + i * simd_w) * typesize]); + } + vaddps(zmm_t(0), zmm_T(1), zmm_T(2)); + vaddps(zmm_t(1), zmm_T(3), zmm_T(4)); + vsubps(zmm_t(2), zmm_T(1), zmm_T(2)); + vsubps(zmm_t(3), zmm_T(3), zmm_T(4)); + + vaddps(zmm_O(0), zmm_t(0), zmm_t(1)); + vaddps(zmm_O(0), zmm_O(0), zmm_T(0)); + fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); + fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); + fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); + vaddps(zmm_O(3), zmm_O(3), zmm_T(5)); + + for (int i = 0; i < tile_size; i++) { + vmovups(ptr[oreg_O + (j * tile_size * simd_w + + i * simd_w) * typesize], zmm_O(i)); + } + } + }; + + auto inner_loops = [=]() { + init_G(); + load_src(); + trans_O_4x4_3x3(); + store_dst(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::input_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int inpw = is_fwd ? jcp.iw : jcp.ow; + int inph = is_fwd ? jcp.ih : jcp.oh; + int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; + int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; + int wp_max = inpw + l_pad; + int hp_max = inph + t_pad; + bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; + int G_size = 9; + + auto zmm_zero = Xbyak::Zmm(0); + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_G = [=](int i) { + return Xbyak::Zmm(1 + i); + }; + auto zmm_I = [=](int i) { + return Xbyak::Zmm(1 + G_size + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(1 + G_size + alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(1 + G_size + 2 * alpha + i); + }; + + auto init_G = [=]() { + mov(ireg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) { + vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp + mov(ireg_I, ptr[param1 + GET_OFF(M)]); + + xor_(ireg_zero, ireg_zero); + vpxord(zmm_zero, zmm_zero, zmm_zero); + + mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]); + shl(ireg_ydim, 2); // tj * tile_size (==4) + mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]); + shl(ireg_xdim, 2); // ti * tilesize (==4) + + for (int j = 0; j < alpha; j++) { + mov(ireg_temp, ireg_ydim); + add(ireg_temp, j); + + mov(ireg_mask_j, 0xffff); + cmp(ireg_temp, t_pad); + cmovl(ireg_mask_j, ireg_zero); + cmp(ireg_temp, hp_max); + cmovge(ireg_mask_j, ireg_zero); + + sub(ireg_temp, t_pad); + imul(ireg_temp, ireg_temp, inpw * simd_w * typesize); + mov(ireg_inp_j, ireg_src); + add(ireg_inp_j, ireg_temp); + + for (int i = 0; i < alpha; i++) { + + mov(ireg_temp, ireg_xdim); + add(ireg_temp, i); + + mov(ireg_mask, 0xffff); + cmp(ireg_temp, l_pad); + cmovl(ireg_mask, ireg_zero); + cmp(ireg_temp, wp_max); + cmovge(ireg_mask, ireg_zero); + and_(ireg_mask, ireg_mask_j); + + sub(ireg_temp, l_pad); + shl(ireg_temp, 4 + 2); + + vpxord(zmm_temp, zmm_temp, zmm_temp); + Opmask kmask = Opmask(7); + kmovw(kmask, ireg_mask_32); + vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]); + vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w) + * typesize], zmm_temp); + } + } + }; + + auto store_Iw = [=]() { + + mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); + mov(ireg_output, ptr[param1 + GET_OFF(dst)]); + + bool streamout + = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) + > 2 * LLC_data_size + ? true : false; + + if (not_tiled) { + mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]); + imul(ireg_tile_block, ireg_tile_block, + alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize); + } + + mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); + imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur, + jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block + * jcp.dimK_reg_block * typesize); + + mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); + imul(ireg_tile_block_ur, ireg_tile_block_ur, + jcp.dimK_reg_block * typesize); + + add(ireg_output, ireg_nb_tile_block_ur); + add(ireg_output, ireg_tile_block_ur); + if (not_tiled) + add(ireg_output, ireg_tile_block); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w + + i * simd_w) * typesize]); + + int j_base_offset = + j * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize; + int i_base_offset = + i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block + * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; + + if (not_tiled && streamout) + vmovntps(ptr[ireg_output + j_base_offset + i_base_offset], + zmm_temp); + else + vmovups(ptr[ireg_output + j_base_offset + i_base_offset], + zmm_temp); + } + } + }; + + auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmulps(zmm_temp, a, b); + vaddps(dst, zmm_temp, c); + }; + + auto trans_I_4x4_3x3 = [=]() { + mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); + mov(ireg_T, ptr[param1 + GET_OFF(T)]); + mov(ireg_I, ptr[param1 + GET_OFF(M)]); + + mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch + for (int i = 0; i < alpha; i++) { + for (int idx = 0; idx < alpha; idx++) { + vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w + + i * simd_w) * typesize]); + int j_base_offset = + i * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize; + int idx_base_offset = + idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block + * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; + prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]); + } + + fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); + fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); + fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); + fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); + fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); + fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); + + fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); + fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); + + for (int idx = 0; idx < alpha; idx++) { + vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w) + * typesize],zmm_T(idx)); + } + } + for (int i = 0; i < alpha; i++) { + for (int idx = 0; idx < alpha; idx++) { + vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx + * simd_w) * typesize]); + } + + fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); + fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); + fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); + fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); + fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); + fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); + + fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); + fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); + + for (int idx = 0; idx < alpha; idx++) { + vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w) + * typesize],zmm_I(idx)); + } + } + }; + + auto inner_loops = [=]() { + init_G(); + load_src(); + trans_I_4x4_3x3(); + store_Iw(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) +{ + if (!mayiuse(avx512_core)) { + return status::unimplemented; + } + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ver = ver_avx512_core; + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + // Checking conditions not supported by these kernels + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) { + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + } + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && (one_of(weights_d.format_kind(), + format_kind::any, format_kind::wino) + || (jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0])); + if (!layout_consistency) + return status::unimplemented; + + return status::success; +} + +void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) { + + /* ----------- dimM reg block ---------------------*/ + auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimM_reg_block, int current_best) { + int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4; + return (dimM_reg_block >= 1) + && (dimM_reg_block <= max_dimM_reg_block ) + && (dimM_reg_block > current_best); + }; + jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp, + jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block); + + /* ----------- dimN reg block ---------------------*/ + + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return jcp.kernel_kind == embd_bcast + ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best + : dimN_reg_block >= 1 + && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block) + < jcp.nb_reg + && dimN_reg_block > current_best; + }; + jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp, + jcp.dimN, 1, test_cond_dimN_reg_block); +} + +status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) { + if (jcp.ver != ver_avx512_core) + return status::unimplemented; + + jcp.kernel_kind = embd_bcast; + + set_kernel_dims_reg_block(jcp); + + /*-------------- L2 blocking for dimN block ---------*/ + + auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_block, int current_best) { + return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0) + && (dimN_block > current_best) + && ((jcp.dimN / dimN_block / jcp.dimN_reg_block) + >= 1.5 * mkldnn_get_max_threads()); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block); + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; + + if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2) + && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) { + + /* ------------------- L1 blocking for GEMM --------------*/ + /* -------------------- Choose dimK block ----------------*/ + + auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp, + int dimK_block, int current_best) { + return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block); + + if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) { + jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; + + /* -------------- Choose dimM block -------------------*/ + auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp, + int dimM_block, int current_best) { + return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block, + 0.2, 0.5) && (dimM_block > current_best); + }; + + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1, + test_cond_dimM_block); + jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block + / jcp.dimM_simd_block; + + jcp.sched_policy = WSCHED_DATA_W_SGD; + return status::success; + } + + } + return status::unimplemented; +} + +void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) { + + set_kernel_dims_reg_block(jcp); + + //********************* Choosing dimK_block **********************// + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, + 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f) + && (dimK_block > current_best); + }; + + auto test_cond1_bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, dimK_block, + jcp.dimK_reg_block, 1, jcp.dimM_reg_block, + jcp.dimM_simd_block, .9f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); + // If we are not able to use streams, we fall back to condition [1] + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); + jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; + + //********************* Choosing dimM_block **********************// + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, .5f) + && (dimM_block > current_best); + }; + + auto test_cond1_bis_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, .3f) + && (dimM_block > current_best); + }; + + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, + test_cond1_dimM_block); + else + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, + test_cond1_bis_dimM_block); + jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block + * jcp.dimM_reg_block); + + //******************* Choosing dimN_block *******************// + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, + jcp.dimM_reg_block, jcp.dimM_simd_block, .9f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); +} + +status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) { + + jcp.kernel_kind = expl_bcast; + set_kernel_blocking_DATA_W_S_G_D(jcp); + if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK, + .1f, .35f))) { + jcp.kernel_kind = embd_bcast; + set_kernel_blocking_DATA_W_S_G_D(jcp); + } + jcp.sched_policy = WSCHED_DATA_W_S_G_D; + return status::success; +} + +status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) +{ + jcp.nb_reg = 32; + jcp.dimN = dimN; + jcp.dimK = dimK; + jcp.dimM = dimM; + jcp.sched_policy = WSCHED_INVALID; + + jcp.dimK_reg_block = 16; + jcp.dimM_simd_block = 16; + + if (jcp.kernel_kind == embd_bcast) { + jcp.dimM_reg_block = 1; + } + + if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success)) + set_wsched_DATA_W_S_G_D_avx512_core(jcp); + + assert(jcp.sched_policy != WSCHED_INVALID); + return status::success; +} + +bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_relu(0) || is_sum(0); // relu or sum + case 2: return (is_sum(0) && is_relu(1)) + || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum + case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu + default: return false; + } + + return false; +} + +status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_t &src_md, memory_desc_t &weights_md, + const memory_desc_t &dst_md, const primitive_attr_t &attr) { + + status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md); + + if (st != status::success) + return st; + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; + jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1; + + status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); + + jcp.ic_simd_block = jcp.dimK_reg_block; + jcp.ic_block = jcp.dimK_block; + jcp.nb_ic = jcp.dimK_nb_block; + jcp.oc_simd_block = jcp.dimM_simd_block; + jcp.oc_block = jcp.dimM_block; + jcp.oc_reg_block = jcp.dimM_reg_block; + jcp.ic_reg_block = 1; + jcp.nb_oc = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + if (cd.prop_kind == mkldnn_forward_inference) { + memory_desc_t expect_wei_md = weights_md; + + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::f32; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format = mkldnn_wino_wei_OBaaIBOIio; + wd.r = 3; + wd.alpha = 6; + + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.dimK_reg_block; + wd.oc_block = jcp.dimM_simd_block; + wd.ic2_block = jcp.dimK_block; + wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block; + size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc; + wd.size = max_size; + wd.adj_scale = 1.f; + + if (weights_md.format_kind == format_kind::any) + weights_md = expect_wei_md; + if (weights_md != expect_wei_md) + return status::unimplemented; + } + + return res; +} + +status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); + + if (st != status::success) + return st; + + jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); + + jcp.oc_simd_block = jcp.dimK_reg_block; + jcp.oc_block = jcp.dimK_block; + jcp.nb_oc = jcp.dimK_nb_block; + jcp.ic_simd_block = jcp.dimM_simd_block; + jcp.ic_block = jcp.dimM_block; + jcp.ic_reg_block = jcp.dimM_reg_block; + jcp.oc_reg_block = 1; + jcp.nb_ic = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + + return res; +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +src_transform_generate() { + constexpr int G_size = 9; + const size_t ifwp = jcp.iw + jcp.l_pad; + const size_t ifhp = jcp.ih + jcp.t_pad; + + auto zmm_G = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_I = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(G_size + alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 2 * alpha + i); + }; + + auto init_G = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) { + vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(reg_I, ptr[reg_transp + GET_OFF(M)]); + xor_(reg_zero, reg_zero); + + mov(reg_ydim, reg_tj); + shl(reg_ydim, 2); //tj * tile_size(=4) + + for (int j = 0; j < alpha; j++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maskj, 0xffff); + cmp(reg_ydim, jcp.t_pad); + cmovl(reg_maskj, reg_zero); + cmp(reg_ydim, ifhp); + cmovge(reg_maskj, reg_zero); + + /*address offset for tile in src*/ + mov(reg_src_offset, reg_ydim); + sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad + imul(reg_src_offset, reg_src_offset, jcp.iw); + + mov(reg_xdim, reg_ti); + shl(reg_xdim, 2); // xdim = ti * tile_size + + add(reg_src_offset, reg_xdim); + sub(reg_src_offset, jcp.l_pad); + imul(reg_src_offset, reg_src_offset, simd_w * typesize); + for (int i = 0; i < alpha; i++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maski, 0xffff); + cmp(reg_xdim, jcp.l_pad); + cmovl(reg_maski, reg_zero); + cmp(reg_xdim, ifwp); + cmovge(reg_maski, reg_zero); + and_(reg_maski, reg_maskj); + + Opmask kmask_src = Xbyak::Opmask(7); + auto zmm_src = Xbyak::Zmm(31); + kmovw(kmask_src, reg_maski_32); + vpxord(zmm_src, zmm_src, zmm_src); + vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]); + vmovups(ptr[reg_I], zmm_src); + + add(reg_xdim, 1); //xdim = ti * tile_size + i + add(reg_src_offset, simd_w * typesize); + add(reg_I, simd_w * typesize); + } + add(reg_ydim, 1); + } + }; + + auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) { + vmovups(dst, c); + vfmadd231ps(dst, a, b); + }; + + auto trans_I_3x3_4x4 = [=]() { + //Use 24 registers + mov(reg_I, ptr[reg_transp + GET_OFF(M)]); + mov(reg_T, ptr[reg_transp + GET_OFF(T)]); + for (int i = 0; i < alpha; i++) { + for (int j = 0; j < alpha; j++) { + size_t I_off = (j * alpha + i) * simd_w * typesize; + vmovups(zmm_I(j), ptr[reg_I + I_off]); + } + + fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); + fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); + fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); + fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); + fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); + fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); + + fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); + fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); + + for (int j = 0; j < alpha; j++) { + vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize], + zmm_T(j)); + } + + } + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]); + } + + fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); + fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); + fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); + fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); + fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); + fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); + + fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); + fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); + + for (int i = 0; i < alpha; i++) { + size_t dst_off = (j * alpha * jcp.ic_block + * jcp.nb_tile_block_ur * jcp.tile_block_ur + + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur) + * simd_w * typesize; + vmovups(ptr[reg_dst + dst_off], zmm_I(i)); + } + } + }; + + auto compute_transform_SDGtWo = [=]() { + mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); + mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + xor_(reg_tile_count, reg_tile_count); + Label loop_mb, loop_jtiles, loop_itiles, done; + L(loop_mb); + { + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_I_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(done); + + add(reg_dst, simd_w * typesize); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + xor_(reg_tj, reg_tj); + add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize); + jmp(loop_mb); + } + L(done); + }; + + auto compute_transform = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + xor_(reg_ti, reg_ti); + xor_(reg_tj, reg_tj); + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, simd_w * typesize); + add(reg_dst, reg_temp); + + Label loop_jtiles, loop_itiles, next_tile_block, next_tile; + L(loop_jtiles); + + { + L(loop_itiles); + { + load_src(); + + trans_I_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(next_tile_block); + add(reg_dst, simd_w * typesize); + jmp(next_tile); + + L(next_tile_block); + sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) + * simd_w * typesize); + size_t tblk_off = alpha * alpha * jcp.ic_block + * jcp.nb_tile_block_ur * jcp.tile_block_ur + * simd_w * typesize; + add(reg_dst, tblk_off); + xor_(reg_tile_count, reg_tile_count); + + L(next_tile); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + }; + + preamble(); + init_G(); + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) + compute_transform_SDGtWo(); + else + compute_transform(); + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +diff_dst_transform_generate(bool with_bias) { + + constexpr int G_size = 8; + auto zmm_G = [](int i) { + return Xbyak::Zmm(31); + }; + + auto zmm_src = [=](int j, int i) { + return Xbyak::Zmm(G_size + j * 4 + i); + }; + + auto zmm_bias = Xbyak::Zmm(31); + + auto load_src = [=]() { + if (with_bias) vmovups(zmm_bias, ptr[reg_bias]); + mov(reg_ydim, reg_tj); + shl(reg_ydim, 2); //tj * tile_size(=4) + for (int j = 0; j < tile_size; j++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maskj, 0xffff); + cmp(reg_ydim, jcp.oh); + cmovge(reg_maskj, reg_zero); + + /*address offset for tile in src*/ + mov(reg_src_offset, reg_ydim); + imul(reg_src_offset, reg_src_offset, jcp.ow); + + mov(reg_xdim, reg_ti); + shl(reg_xdim, 2); // xdim = ti * tile_size + + add(reg_src_offset, reg_xdim); + imul(reg_src_offset, reg_src_offset, simd_w * typesize); + for (int i = 0; i < tile_size; i++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maski, 0xffff); + cmp(reg_xdim, jcp.ow); + cmovge(reg_maski, reg_zero); + and_(reg_maski, reg_maskj); + + Opmask kmask_src = Xbyak::Opmask(7); + kmovw(kmask_src, reg_maski_32); + vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i)); + vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]); + if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias, + ptr[reg_src + reg_src_offset]); + + add(reg_xdim, 1); //xdim = ti * tile_size + i + add(reg_src_offset, simd_w * typesize); + } + add(reg_ydim, 1); + } + if(with_bias) vmovups(ptr[reg_bias], zmm_bias); + }; + + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 16 + i); + }; + + auto zmm_T = [=](int j, int i) { + return Xbyak::Zmm(j * 4 + i); + }; + + auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) { + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) + vmovups(ptr[reg_dst + dst_off], a); + else + vmovntps(ptr[reg_dst + dst_off], a); + }; + + auto trans_W_3x3_4x4 = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < tile_size; i++) { + vbroadcastss(zmm_G(0), ptr[reg_G]); + vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0)); + + vbroadcastss(zmm_G(1), ptr[reg_G + typesize]); + vmovups(zmm_t(1), zmm_t(0)); + vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1)); + + vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); + vmovups(zmm_t(2), zmm_t(0)); + vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2)); + + vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); + vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3)); + + vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); + vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4)); + + vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); + vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5)); + + vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); + vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6)); + + vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); + vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7)); + vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3)); + vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3)); + vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4)); + vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4)); + vmovups(zmm_T(5, i), zmm_src(3, i)); + } + + for (int j = 0; j < alpha; j++) { + vbroadcastss(zmm_G(0), ptr[reg_G]); + vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0)); + + vbroadcastss(zmm_G(1), ptr[reg_G + typesize]); + vmovups(zmm_t(1), zmm_t(0)); + vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1)); + + vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); + vmovups(zmm_t(2), zmm_t(0)); + vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2)); + + vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); + vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3)); + + vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); + vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4)); + + vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); + vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5)); + + vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); + vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6)); + + vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); + vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7)); + vsubps(zmm_t(5), zmm_t(1), zmm_t(3)); + vaddps(zmm_t(1), zmm_t(1), zmm_t(3)); + vaddps(zmm_t(6), zmm_t(2), zmm_t(4)); + vsubps(zmm_t(2), zmm_t(2), zmm_t(4)); + vmovups(zmm_t(3), zmm_T(j, 3)); + + int alpha_offset = (jcp.oc / jcp.nb_oc) + * (jcp.ntiles / jcp.tile_block) * typesize; + int dst_off = j * alpha * alpha_offset; + movps(reg_dst, dst_off, zmm_t(0)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(5)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(1)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(6)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(2)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(3)); + } + + }; + auto compute_transform_SDGtWo = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); + + xor_(reg_zero, reg_zero); + xor_(reg_oc_ur, reg_oc_ur); + Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done; + + L(loop_oc_ur); + { + mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); + mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); + xor_(reg_tile_count, reg_tile_count); + L(loop_mb); + { + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_W_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(tiles_done); + + add(reg_dst, jcp.oc_reg_block * simd_w * typesize); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + xor_(reg_tj, reg_tj); + add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize); + jmp(loop_mb); + } + + L(tiles_done); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + add(reg_dst, simd_w * typesize); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); + + if (with_bias) add(reg_bias, simd_w * typesize); + add(reg_oc_ur, 1); + cmp(reg_oc_ur, jcp.oc_reg_block); + jl(loop_oc_ur); + } + }; + + auto compute_transform = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, reg_temp); + + xor_(reg_zero, reg_zero); + xor_(reg_oc_ur, reg_oc_ur); + Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile; + + L(loop_oc_ur); + { + xor_(reg_ti, reg_ti); + xor_(reg_tj, reg_tj); + + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_W_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(next_tile_block); + add(reg_dst, jcp.oc_reg_block * simd_w * typesize); + jmp(next_tile); + + L(next_tile_block); + sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) + * jcp.oc_reg_block * simd_w * typesize); + int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc) + * (jcp.ntiles/jcp.tile_block) * typesize; + add(reg_dst, tblk_off); + xor_(reg_tile_count, reg_tile_count); + + L(next_tile); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, reg_temp); + add(reg_dst, simd_w * typesize); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); + + if (with_bias) add(reg_bias, simd_w * typesize); + add(reg_oc_ur, 1); + cmp(reg_oc_ur, jcp.oc_reg_block); + jl(loop_oc_ur); + } + }; + + preamble(); + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { + compute_transform_SDGtWo(); + } else { + compute_transform(); + } + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +diff_weights_transform_generate(bool first_tile) { + int G_size = 4; + + auto zmm_G = [](int i) { + return Xbyak::Zmm(i); + }; + + auto init_G = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) + vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); + }; + + auto zmm_src = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + + auto load_src = [=](int i) { + for (int j = 0; j < alpha; j++) { + size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block + * jcp.ic_block * simd_w * simd_w * typesize; + size_t src_off = (j * alpha + i) * alpha_offset; + vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off)); + } + }; + + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 6 + i); + }; + + auto zmm_T = [=](int j, int i) { + return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i); + }; + + auto zmm_dst = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + + auto zmm_temp = Xbyak::Zmm(31); + + auto store_dst = [=](int j) { + for (int i = 0; i < jcp.kw; i++) { + size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize; + + if (!first_tile) { + vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off)); + vaddps(zmm_dst(i), zmm_dst(i), zmm_temp); + } + vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i)); + } + }; + + auto compute_transform = [=] () { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + + xor_(reg_ic_simd, reg_ic_simd); + Label loop_ic_simd; + L(loop_ic_simd); + { + for (int i = 0; i < alpha; i++) { + load_src(i); + + vaddps(zmm_t(0), zmm_src(1), zmm_src(2)); + vaddps(zmm_t(1), zmm_src(3), zmm_src(4)); + vmovups(zmm_t(2), zmm_src(5)); + vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); + + vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0)); + vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1)); + vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2)); + vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1)); + vsubps(zmm_temp, zmm_src(3), zmm_src(4)); + vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2)); + vmovups(zmm_T(2, i), zmm_t(2)); + vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3)); + } + + for (int j = 0; j < jcp.kh; j++) { + vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2)); + vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4)); + vmovups(zmm_t(2), zmm_T(j, 5)); + vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); + + vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0)); + vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1)); + vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2)); + vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1)); + vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4)); + vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2)); + vmovups(zmm_dst(2), zmm_t(2)); + vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3)); + + store_dst(j); + } + + add(reg_src, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, simd_w * typesize); + add(reg_ic_simd, 1); + cmp(reg_ic_simd, simd_w); + jl(loop_ic_simd); + } + }; + preamble(); + push(reg_EVEX_max_8b_offt); + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + init_G(); + compute_transform(); + pop(reg_EVEX_max_8b_offt); + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate( + bool is_first_tile) +{ + auto zmm_srcA = [=]() { + return Xbyak::Zmm(0); + }; + + auto zmm_srcB = [=] (size_t N_ur){ + return Xbyak::Zmm(N_ur + 1); + }; + + auto broadcastB = [=](size_t K_ur) { + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { + size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast) + * sizeof(float); + vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off)); + } + }; + + auto load_srcA = [=] (size_t K_ur, int M_ur) { + size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block + + M_ur * jcp.dimM_simd_block) * sizeof(float); + vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off)); + }; + + auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){ + size_t idx = 1 // zmm_srcA + + jcp.dimN_bcast_ur // zmm_srcB + + M_reg_ur * jcp.dimN_bcast_ur + N_bcast; + assert(idx < 32); + return Xbyak::Zmm(idx); + }; + auto prepare_accumm = [=](){ + for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { + Zmm zmm = zmm_dstC(M_reg_ur, N_bcast); + vpxord(zmm, zmm, zmm); + } + } + }; + + auto store_dstC = [=](){ + /******** Write C back to memory *******/ + for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) { + for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) { + Zmm zmm = zmm_dstC(M_reg, N_ur); + size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block + + M_reg * jcp.dimM_simd_block) * sizeof(float); + if (!is_first_tile) { + vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off)); + vaddps(zmm, zmm, Xbyak::Zmm(0)); + } + vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm); + } + } + }; + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur; + + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + { /************* OC_block (M) loop ***********/ + mov(reg_dimN_block_loop_cnt, jcp.dimN_block); + L(dimN_block_loop); + { /*************** IC_block (N) loop *********/ + + mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur); + L(dimN_bcast_ur); + { + prepare_accumm(); + + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + { + /************* nb_tile_ur(K) loop ********/ + for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) { + + broadcastB(K_ur); + + for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { + load_srcA(K_ur, M_reg_ur); + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) { + vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(), + zmm_srcB(N_bcast)); + } + } + } + add(reg_srcA, jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + add(reg_srcB, jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + store_dstC(); + + sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float)); + add(reg_dstC, jcp.dimN_bcast_ur + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_nb_dimN_bcast_ur, 1); + jnz(dimN_bcast_ur); + } + + sub(reg_srcB, jcp.dimN_reg_block * sizeof(float)); + add(reg_srcB, jcp.dimK_block + * jcp.dimK_reg_block + * jcp.dimN_reg_block * sizeof(float)); + sub(reg_dimN_block_loop_cnt, 1); + jnz(dimN_block_loop); + } + + sub(reg_srcB, jcp.dimN_block + * jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + }; + + /* Preamble */ + preamble(); + + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +namespace { + +void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) { +/*M params*/ + jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block + / jcp.dimM_simd_block; + jcp.oc_reg_block = jcp.dimM_reg_block; + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + /*N params*/ + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; + jcp.ic_block = jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + /*K params*/ + jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; +} + +status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) { + + size_t K_blk_ur, N_blk, M_blk; + /* IS this strategy feasible? */ + auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) { + size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float); + size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float); + size_t nthreads = mkldnn_get_max_threads(); + return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size) + && (jcp.dimK / nthreads >= 1.0); + }; + + auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur, + int max_block=1) { + size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float); + size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float); + size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float); + size_t nthreads = mkldnn_get_max_threads(); + bool load_balance=true; + if (!(jcp.dimK % nthreads)) { + load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0); + } + return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size) + && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size) + && load_balance + && (M_L2_block < L2_cache_size); + }; + + auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, + int useless_arg=0) { + return (dimK_ur >= 2) && (dimK_ur <= 8); + }; + + auto blocking_ok = [&](){ + size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block + * K_blk_ur * sizeof(float); + size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block + * K_blk_ur * sizeof(float); + size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block + * N_blk * jcp.dimN_reg_block * sizeof(float); + size_t L2_block = M_L2_block + V_L2_block + U_L2_block; + /*Replace 2.375 with L2+L3 cache size*/ + return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size); + }; + + if (test_MV_large_enough(jcp)) { + if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { + jcp.dimM_reg_block = 2; + } else { + jcp.dimM_reg_block = 1; + } + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.dimN_bcast_ur = 8; + /*dimK_block and dimK_ur*/ + size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1); + + jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block; + jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block; + for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) { + if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) { + for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { + if (!(jcp.dimN_block % N_blk)) { + for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { + if (!(jcp.dimM_block % M_blk) && blocking_ok()) { + jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); + if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented; + jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; + jcp.dimN_block = N_blk; + jcp.dimM_block = M_blk; + jcp.sched_policy = WSCHED_WEI_SDGtWo; + set_jcp_WEI_params(jcp); + jcp.nthr = nstl::min(mkldnn_get_max_threads(), + jcp.tile_block); + return status::success; + } + } + } + } + } + } + } + return status::unimplemented; +} + +status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) { + if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { + jcp.dimM_reg_block = 2; + } else { + jcp.dimM_reg_block = 1; + } + jcp.dimN_bcast_ur = 8; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block; + jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block; + float C1 = 0.0, C2 = 0.0; + float C1_max = 0.5, C2_max = 1.4; + int N_blk, M_blk, K_blk_ur; + + auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, + int useless_arg=0) { + return (dimK_ur >= 2) && (dimK_ur <= 8); + }; + + auto blocking_ok = [&]() -> bool { + size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float); + size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float); + bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size) + && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size); + + size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block; + size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block; + size_t nb_K_blk = jcp.dimK / K_blk_ur; + size_t nthreads = mkldnn_get_max_threads(); + bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads; + if (!(nb_K_blk % nthreads)) { + load_balance = load_balance && (nb_K_blk % nthreads == 0); + } + + size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float); + + size_t L2_block = V_L2_block; + /*Replace 2.375 with L2+L3 cache size*/ + bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size); + return L1_cond && load_balance && L2_cond; + }; + + for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) { + if (jcp.dimK % K_blk_ur == 0) { + for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { + if (jcp.dimN_block % N_blk == 0) { + for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { + if (jcp.dimM_block % M_blk == 0) { + if (blocking_ok()) { + jcp.dimN_block = N_blk; + jcp.dimM_block = M_blk; + jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); + jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; + jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; + set_jcp_WEI_params(jcp); + return status::success; + } + } + } + } + } + } + } + jcp.dimK_reg_block = 1; + jcp.dimK_block = 1; + jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; + set_jcp_WEI_params(jcp); + return status::success; +} +} // namespace +status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d) { + if (!mayiuse(avx512_core)) + return status::unimplemented; + else + jcp.ver = ver_avx512_core; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.prop_kind = cd.prop_kind; + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + jcp.mb = src_d.dims()[0]; + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + jcp.kh = diff_weights_d.dims()[with_groups + 2]; + jcp.kw = diff_weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + // Winograd kernel works only for 3x3 convolution with stride 1 + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + /******************Kernel blocking Parameters ***********/ + jcp.ic_simd_block = simd_w; + jcp.oc_simd_block = simd_w; + + jcp.dimK = jcp.ntiles; + jcp.dimN = jcp.ic; + jcp.dimM = jcp.oc; + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.sched_policy = WSCHED_INVALID; + status_t res = set_wsched_WEI_SDGtWo(jcp); + if (res == status::unimplemented) { + res = set_wsched_WEI_S_D_Giot_W(jcp); + assert(res == status::success); + } + return res; +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp new file mode 100644 index 0000000000..025a554d92 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp @@ -0,0 +1,291 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP +#define JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP + +#include "c_types_map.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + : public jit_generator { + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) { + { + this->weights_transform_data_ker_generate(); + weights_transform_data_ker + = (decltype(weights_transform_data_ker)) this->getCode(); + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->input_transform_data_ker_generate(); + input_transform_data_ker = (decltype(input_transform_data_ker))addr; + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->output_transform_data_ker_generate(); + output_transform_data_ker + = (decltype(output_transform_data_ker))addr; + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_fp32_wino_conv_4x3_data_kernel) + + static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d); + + static status_t init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *, const int); + void (*input_transform_data_ker)(jit_wino_transform_call_s *); + void (*output_transform_data_ker)(jit_wino_transform_call_s *); + void (*weights_transform_data_ker)(jit_wino_transform_call_s *); + +protected: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(); + void input_transform_data_ker_generate(); + void output_transform_data_ker_generate(); + void weights_transform_data_ker_generate(); + + /* registers used for GEMM */ + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + reg64_t reg_is_beta_zero = abi_param4; + + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r11; + + /* registers used for transforms*/ + reg64_t param = abi_param1; + + /* registers used for output_transform_data_ker */ + reg64_t oreg_temp = abi_not_param1; + reg64_t oreg_Ow = r9; + reg64_t oreg_src = r11; + reg64_t oreg_tile_block = r12; + reg64_t oreg_tile_block_ur = r13; + reg64_t oreg_nb_tile_block_ur = r14; + reg64_t oreg_O = r8; + reg64_t oreg_T = r10; + reg64_t oreg_dst = r11; + reg64_t oreg_ydim = r14; + reg64_t oreg_xdim = r15; + reg64_t oreg_out_j = r12; + reg64_t oreg_bias = rbx; + reg64_t imm_addr64 = rax; + + /* registers used for input_transform_data_ker */ + reg64_t ireg_temp = abi_not_param1; + reg64_t ireg_jtiles = rax; + reg64_t ireg_itiles = rbx; + reg64_t ireg_I = r8; + reg64_t ireg_src = r13; + reg64_t ireg_ydim = r14; + reg64_t ireg_xdim = r15; + reg64_t ireg_inp_j = r12; + reg64_t ireg_inp_i = rdx; + reg64_t ireg_mask_j = r11; + reg64_t ireg_mask = rsi; + reg32_t ireg_mask_32 = esi; + reg64_t ireg_zero = r9; + reg64_t ireg_Iw = r9; + reg64_t ireg_T = r10; + reg64_t ireg_tile_block = r12; + reg64_t ireg_tile_block_ur = r13; + reg64_t ireg_nb_tile_block_ur = r14; + reg64_t ireg_output = r15; + + /* registers used for wei transform */ + reg64_t wreg_temp = abi_not_param1; + reg64_t wreg_F = r8; + reg64_t wreg_src = r9; + reg64_t wreg_MT = r15; + reg64_t wreg_M = r14; + reg64_t wreg_dst = r10; + reg64_t wreg_dst_aux = r9; + reg64_t wreg_dst_idx = r8; + reg64_t wreg_Fw = r11; + reg64_t wreg_T = r12; + reg64_t wreg_cnt_j = rdx; + reg64_t wreg_F_aux = r14; + reg64_t wreg_Fw_aux = r15; +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel + : _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { + using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_t &src_md, + memory_desc_t &weights_md, const memory_desc_t &dst_md, + const primitive_attr_t &attr); +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel + : public _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { + using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel + : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32) + + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + //******************* First iter kernel ********************// + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode(); + + align(); + const Xbyak::uint8 *addr = getCurr(); + this->src_transform_generate(); + src_transform = (decltype(src_transform))addr; + + if (jcp.with_bias) { + align(); + addr = getCurr(); + this->diff_dst_transform_generate(true); + diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr; + } + + align(); + addr = getCurr(); + this->diff_dst_transform_generate(false); + diff_dst_transform = (decltype(diff_dst_transform))addr; + + if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) { + align(); + addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + + align(); + addr = getCurr(); + this->diff_weights_transform_generate(true); + diff_weights_transform = (decltype(diff_weights_transform))addr; + + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { + align(); + addr = getCurr(); + this->diff_weights_transform_generate(false); + diff_weights_transform_accum = + (decltype(diff_weights_transform_accum))addr; + }; + } + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + void (*src_transform)(jit_wino_transform_call_s *); + void (*diff_dst_transform)(jit_wino_transform_call_s *); + void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *); + void (*diff_weights_transform)(jit_wino_transform_call_s *); + void (*diff_weights_transform_accum)(jit_wino_transform_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + enum { typesize = sizeof(float) }; + + void src_transform_generate(); + void diff_dst_transform_generate(bool with_bias); + void diff_weights_transform_generate(bool first_tile); + + /*registers common to transforms*/ + reg64_t reg_transp = abi_param1; + reg64_t reg_ti = rbx; + reg64_t reg_tj = abi_not_param1; + reg64_t reg_src = r8; + reg64_t reg_dst = r9; + reg64_t reg_G = rsi; /*TODO: check if this is ok*/ + reg64_t reg_temp = rsi; + + /*registers common to src/diff_dst transform*/ + reg64_t reg_I = r10; + reg64_t reg_ydim = r11; + reg64_t reg_xdim = r12; + reg64_t reg_src_offset = r13; + reg64_t reg_zero = r14; + reg64_t reg_tile_count = r15; + reg64_t reg_maski = rsi; + reg32_t reg_maski_32 = esi; + reg64_t reg_maskj = rdx; + + reg64_t reg_T = rax; + reg64_t reg_oc_ur = rax; + reg64_t reg_ic_simd = r14; + reg64_t reg_bias = r10; + + void gemm_loop_generate(bool is_first_tile); + + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_dimM_block_loop_cnt = r9; + reg64_t reg_dimN_block_loop_cnt = r10; + reg64_t reg_nb_dimN_bcast_ur = r11; + reg64_t reg_dimK_block_loop_cnt = r12; +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp new file mode 100644 index 0000000000..002010ffa2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp @@ -0,0 +1,1284 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_u8s8s32x_wino_convolution.hpp" +#include "jit_generator.hpp" + +#include + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { + // Below scales are applied to source and weights data accordingly + // because this winograd implementation + // transforms source which may increase values up to 4x + // and transforms weights which may increase values up to 9/4x + const float adj_src_scale = 1.f / 4.f; + const float adj_wei_scale = 4.f / 9.f; + // Winograd transforms need ic and oc to be multiples of 16 + const int load_block = 16; +} + +/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *src; + const void *wino_src; + const void *v_y_masks; + const void *v_x_masks; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + void generate(); + + int reg_inp_ind(int i) { + assert(i < jcp.alpha * jcp.alpha); + return (31 - i); + } + + Xmm vreg_inp(int i) { + return Xmm(reg_inp_ind(i)); + } + + Zmm zmm_inp(int i) { + return Zmm(reg_inp_ind(i)); + } + + Xmm vreg_tmp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Xmm(15 - i); + } + Xmm vreg_out(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Xmm(31 - i); + } + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert(id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_src = r14; + Reg64 reg_ptr_dst = r13; + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_ic_block = r8; + + int unsign_val_in_wino_domain; + + Reg64 reg_scratch_src_alpha = rdx; + Xmm xmm_src_alpha = Xmm(0); + Zmm zmm_src_alpha = Zmm(0); + + Reg64 reg_shift = rax; + Xmm xmm_shift = Xmm(1); + Xmm xmm_zero = Xmm(0); + + Reg64 reg_maskx = rbx; + Reg64 reg_masky = rsi; + Reg64 reg_nomask = reg_maskx; +}; + +void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() { + Label ic_block_label; + Label end_label; + Label mask_label; + Label nomask_label; + + auto load_src = [=](bool mask) { + for (int y = 0; y < jcp.alpha; y++) { + if (mask) + kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]); + for (int x = 0; x < jcp.alpha; x++) { + Zmm zmm_i = zmm_inp(y * jcp.alpha + x); + Xmm vreg_i = vreg_inp(y * jcp.alpha + x); + int inp_offset = sizeof(uint8_t) + * ((-jcp.t_pad + y) * jcp.iw * jcp.ic + + (-jcp.l_pad + x) * jcp.ic); + if (mask) { + kandw(r_mask, y_mask, x_mask(x)); + vmovdqu8(vreg_i | r_mask | T_z, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } else { + vmovdqu8(vreg_i, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } + vpmovzxbd(zmm_i, vreg_i); // to int32 + vcvtdq2ps(zmm_i, zmm_i); // to fp32 + vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha + vcvtps2dq(zmm_i, zmm_i); // to int32 + vpmovusdb(vreg_i, zmm_i); // to u8 + } + } + }; + + preamble(); + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_ptr_dst, wino_src); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); +# undef READ_PARAM + + mov(reg_maskx, ptr[reg_ptr_v_x_masks]); + mov(reg_masky, ptr[reg_ptr_v_y_masks]); + test(reg_maskx, reg_maskx); + jz(end_label, T_NEAR); // skip kernel if x mask is all 0's + test(reg_masky, reg_masky); + jz(end_label, T_NEAR); // skip kernel if y mask is all 0's + and_(reg_maskx, reg_masky); + mov(reg_nomask, reg_maskx); + not_(reg_nomask); // zero if x and y masks are all 1's + + xor_(reg_shift, reg_shift); + mov(reg_shift.cvt8(), (int8_t)-128); + + mov(reg_aux_ptr_src, reg_ptr_src); + mov(reg_aux_ptr_dst, reg_ptr_dst); + + for (int i = 0; i < jcp.alpha; i++) { + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); + } + + mov(reg_scratch_src_alpha, float2int(adj_src_scale)); + + mov(reg_ic_block, jcp.ic / load_block); + L(ic_block_label); + { + vmovq(xmm_src_alpha, reg_scratch_src_alpha); + vbroadcastss(zmm_src_alpha, xmm_src_alpha); + + test(reg_nomask, reg_nomask); + jz(nomask_label, T_NEAR); + load_src(true); + jmp(mask_label, T_NEAR); + L(nomask_label); + load_src(false); + L(mask_label); + + for(int y = 0; y < 4; y++) { + vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2)); + vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2)); + vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1)); + vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3)); + } + for(int x = 0;x < 4; x++) { + vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2)); + vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2)); + vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1)); + vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3)); + } + + vmovd(xmm_shift, reg_shift.cvt32()); + vpxor(xmm_zero, xmm_zero, xmm_zero); + vpshufb(xmm_shift, xmm_shift, xmm_zero); + + for (int i = 0; i < 16; i++) { + int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i); + if (i != unsign_val_in_wino_domain) + vpsubb(vreg_out(i), vreg_out(i), Xmm(1)); + vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i)); + } + + add(reg_aux_ptr_src, sizeof(uint8_t) * load_block); + add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block); + } + dec(reg_ic_block); + jnz(ic_block_label, T_NEAR); + + L(end_label); + postamble(); +} + +/// DST TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *wino_dst; + const void *dst; + const void *v_y_masks; + const void *v_x_masks; + + const void *bias; + const void *scales; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + + void generate(); + bool maybe_relu(int position); + + Zmm vreg_inp(int i) { // 16 + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + Zmm vreg_stg(int id) { // 8 + const int id_reg_stg = jcp.alpha * jcp.alpha + id; + assert(id < 8); + return Zmm(31 - id_reg_stg); + } + Zmm vreg_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id < 4); + return Zmm(31 - id_reg_out); + } + Xmm xmm_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id < 4); + return Xmm(31 - id_reg_out); + } + Zmm vreg_tmp(int id) { // 2 + const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; + assert(id < 2); + return Zmm(31 - id_reg_tmp); + } + + Zmm vreg_zero = Zmm(0); + Zmm vreg_bias = Zmm(1); + Zmm vreg_prev_dst = Zmm(2); + Zmm zmm_bias_alpha = Zmm(2); + Xmm xmm_bias_alpha = Xmm(2); + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert(id < 4); + return Opmask(3 + id); + } + + Reg64 reg_scratch_bias_alpha = r15; + + Reg64 reg_ptr_src = r14; + Reg64 reg_ptr_dst = r13; + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_oc_block = r8; + + Reg64 reg_ptr_bias = rbx; + Reg64 reg_ptr_scales = abi_not_param1; + Reg64 reg_ptr_sum_scale = rdx; +}; + +bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* relu before sum */ + return false + || p.contain(eltwise, 0) + || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0)); + } else if (position == 1) { + /* relu after sum */ + const int sum_idx = p.contain(sum, 0) + ? 0 : (p.contain(sum, 1) ? 1 : -1); + if (sum_idx == -1) + return false; + + return false + || p.contain(eltwise, sum_idx + 1) + || jcp.dst_dt == data_type::u8; + } + + return false; +} + +void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { + Label oc_block_label; + + auto loop_body = [=]() { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + for(int i = 0; i < 16; i++) { + int internal_offset = sizeof(int32_t) * jcp.out_stride * i; + vmovups(vreg_inp(i), + EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); + } + for(int y = 0; y < jcp.alpha; y++) { + vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1)); + vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2)); + + vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2)); + vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3)); + } + for(int x = 0; x < jcp.m; x++) { + vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1)); + vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2)); + + vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2)); + vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3)); + } + + + if (jcp.with_bias) { + vmovq(xmm_bias_alpha, reg_scratch_bias_alpha); + vbroadcastss(zmm_bias_alpha, xmm_bias_alpha); + + auto bias_addr = ptr [ reg_ptr_bias ]; + switch (jcp.bia_dt) { + case data_type::f32: + case data_type::s32: vmovups(vreg_bias, bias_addr); break; + case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break; + case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break; + default: assert(!"unsupported dst data type"); + } + if (jcp.bia_dt != data_type::f32) + vcvtdq2ps(vreg_bias, vreg_bias); + vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha + } + for(int y = 0; y < jcp.m; y++) { + kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]); + for(int x = 0; x < jcp.m; x++) { + kandw(r_mask, y_mask, x_mask(x)); + + int i = y * jcp.m + x; + int offset = jcp.typesize_out * + (y * jcp.ow * jcp.oc + x * jcp.oc); + Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); + + Zmm zmm = vreg_out(i); + Xmm xmm = xmm_out(i); + vcvtdq2ps(zmm, zmm); + if (jcp.with_bias) + vaddps(zmm, zmm, vreg_bias); + vmulps(zmm, zmm, ptr [reg_ptr_scales]); + if (maybe_relu(0)) + vmaxps(zmm, vreg_zero, zmm); + if (p_sum_scale) { // post_op: sum + vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(vreg_prev_dst | r_mask, addr); break; + case data_type::s8: + vpmovsxbd(vreg_prev_dst | r_mask, addr); break; + case data_type::u8: + vpmovzxbd(vreg_prev_dst | r_mask, addr); break; + default: assert(!"unknown dst_dt"); + } + if (jcp.dst_dt != data_type::f32) + vcvtdq2ps(vreg_prev_dst, vreg_prev_dst); + if (*p_sum_scale == 1.f) + vaddps(zmm, vreg_prev_dst); + else + vfmadd231ps(zmm, vreg_prev_dst, + zword_b[reg_ptr_sum_scale]); + } + if (maybe_relu(1)) + vmaxps(zmm, vreg_zero, zmm); + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(zmm, zmm); + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(addr, zmm | r_mask); break; + case data_type::s8: + vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; + case data_type::u8: + vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; + default: assert(!"unknown dst_dt"); + } + } + } + }; + + preamble(); + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, wino_dst); + READ_PARAM(reg_ptr_dst, dst); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); + READ_PARAM(reg_ptr_bias, bias); + READ_PARAM(reg_ptr_scales, scales); +# undef READ_PARAM + + if (jcp.with_bias) + mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale)); + + mov(reg_aux_ptr_src, reg_ptr_src); + mov(reg_aux_ptr_dst, reg_ptr_dst); + + vpxord(vreg_zero, vreg_zero, vreg_zero); + + for (int i = 0; i < jcp.m; i++) + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); + + int oc_blocks = jcp.oc / load_block; + mov(reg_oc_block, oc_blocks); + L(oc_block_label); { + loop_body(); + add(reg_aux_ptr_src, sizeof(int32_t) * load_block); + add(reg_aux_ptr_dst, jcp.typesize_out * load_block); + + add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block); + } + dec(reg_oc_block); + jnz(oc_block_label, T_NEAR); + + postamble(); + +} + +/// GEMM kernel //////////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t) + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *src; + const void *dst; + const void *wei; + const void *dst_b; + }; + void (*ker_)(const call_params_t *); + + void generate(); + static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, + const primitive_attr_t &attr); + + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) + { + generate(); + ker_ = reinterpret_cast(const_cast(getCode())); + } + + static status_t init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr); + + Zmm vreg_out(int n, int m) { + const int id_reg_out = n * jcp.m_block + m; + assert(id_reg_out < jcp.n2_block * jcp.m_block); + return Zmm(31 - id_reg_out); + } + Zmm vreg_wei(int i) { + assert(31 - jcp.n2_block * jcp.m_block - i + > (jcp.ver == ver_vnni ? 0 : 2)); + return Zmm(31 - jcp.n2_block * jcp.m_block - i); + } + + Zmm vreg_src = Zmm(0); + Zmm vreg_one = Zmm(1); + Zmm vreg_tmp = Zmm(2); + + Reg64 reg_ptr_src = r15; + + Reg64 reg_aux_dst_b = r13; + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; + Reg64 reg_aux_wei = r10; + Reg64 reg_aux_wei2 = r9; + Reg64 reg_aux_src = r8; + Reg64 reg_aux_src2 = rax; + Reg64 reg_mb = rbx; + Reg64 reg_nnb = abi_not_param1; + Reg64 reg_scratch = rdx; + Reg64 reg_K = rsi; +}; + +bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok( + jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_relu(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_relu(1)) || + (p.contain(sum, 1) && is_relu(0)); + case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); + default: return false; + } + + return false; +} + +void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() { + Label nnb_loop_label, K_loop_label, mb_loop_label; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(vreg_tmp, vreg_src, vreg_wei); + vpmaddwd(vreg_tmp, vreg_tmp, vreg_one); + vpaddd(vreg_acc, vreg_acc, vreg_tmp); + } + }; + + preamble(); +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); + READ_PARAM(reg_aux_dst_b, dst_b); +# undef READ_PARAM + + if (jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(vreg_one, _t); + } + + if (!jcp.small_mb) { + mov(reg_nnb, jcp.n_chunks); + L(nnb_loop_label); + } + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + int offset = jcp.typesize_acc * nb2 * jcp.n_block; + vmovups(vreg_out(nb2, m), + EVEX_compress_addr(reg_aux_dst_b, offset)); + } + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + mov(reg_K, jcp.k_chunks); + L(K_loop_label); + { + for (int k = 0; k < jcp.k2_block; k += 4) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int wei_offset + = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K); + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, wei_offset)); + } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = jcp.typesize_in * m * jcp.K; + vpbroadcastd(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src); + } + add(reg_aux_src2, jcp.typesize_in * 4); + add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block); + } + } + dec(reg_K); + jnz(K_loop_label, T_NEAR); + + for (int m = 0; m < jcp.m_block; m++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block); + vmovups(EVEX_compress_addr(reg_aux_dst2, offset), + vreg_out(nb2, m)); + } + } + add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K); + add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N); + } + dec(reg_mb); + jnz(mb_loop_label, T_NEAR); + + if (!jcp.small_mb) { + add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block); + add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block); + add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K); + + dec(reg_nnb); + jnz(nnb_loop_label, T_NEAR); + } + + postamble(); +} +namespace { +bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { + if (jcp.ver == ver_vnni) { + return (jcp.mb <= mkldnn_get_max_threads() + && (jcp.mb > 4 + && jcp.ic > 64 + && !(jcp.oc > 128 && jcp.ih < 14))) + || jcp.mb > mkldnn_get_max_threads(); + } + return true; +} +} + +status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t +::init_conf(jit_conv_conf_2x3_wino_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = wei_d.dims()[with_groups + 2]; + jcp.kw = wei_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ver = ver_avx512_core; + if (!(mayiuse(avx512_core) && + src_d.data_type() == data_type::u8 + && wei_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + // block sizes needed for GEMM kernel + jcp.ic_block = 4; + jcp.oc_block = 16; + + bool ok = true + && jcp.ngroups == 1 + && jcp.oc % load_block == 0 && jcp.ic % load_block == 0 + && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 + && everyone_is(3, jcp.kh, jcp.kw) + && everyone_is(1, jcp.stride_h, jcp.stride_w) + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) + && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad + && one_of(jcp.t_pad, 0, 1) + && one_of(jcp.l_pad, 0, 1); + if (!ok) return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_acc = sizeof(int32_t); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.m = 2; + jcp.r = 3; + jcp.alpha = jcp.m + jcp.r - 1; + + int aa = jcp.alpha * jcp.alpha; + int L1_cap = get_cache_size(1, true); + int L2_cap = get_cache_size(2, true); + // need 1 extra reg for bcast, and 2 tmp regs for non-vnni + int free_regs = jcp.ver == ver_vnni ? 31 : 29; + + auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float thr_eff; + float Z = (float)jcp.ic + jcp.oc; + float Y = (float)jcp.ic * jcp.oc; + if (small_mb == 0) { // outer par + int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); + } else { // inner par + int tranw = iy * ix / jcp.alpha; + int gemmw = aa * (jcp.nb_oc / n2_b); + int tranw_r = rnd_up(tranw, jcp.nthr); + int gemmw_r = rnd_up(gemmw, jcp.nthr); + thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y); + } + return thr_eff; + }; + + auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float mem_eff, req_mem; + int M = ix * iy / jcp.alpha; + if (small_mb == 0) { // outer parallelization strategy + // memory for wino transforms (other memory has poor reuse) + req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc); + mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f; + } else { // inner parallelization strategy + // memory used during gemm + int N = jcp.oc_block * n2_b; + req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + // memory used during wino transforms + int M_per_thr = div_up(M, jcp.nthr); + req_mem = (float)aa * M_per_thr + * (jcp.ic + jcp.typesize_acc * jcp.oc); + if (req_mem > L2_cap) + mem_eff = 0.1f; + } + return mem_eff; + }; + + auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff, + float mem_eff, float reg_eff) { + // these coefficients are chosen empirically + float mem_fac = 0.1f, reg_fac = 0.2f; + // normalized overhead relative to memory and register components + float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff; + // thread and work components affect all others + tot_eff *= thr_eff * work_eff; + return tot_eff; + }; + + auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff, + int &m_block, int &n2_block, float &tot_eff) { + int M = (ix * iy) / jcp.alpha; + int max_m_block = nstl::min(M, free_regs); + int max_n2_block = nstl::min(jcp.nb_oc, free_regs); + tot_eff = 0.f; + for (int im = max_m_block; im > 0; im--) { + if (M % im) + continue; + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = (im + 1) * in2; + float mem_eff = get_mem_eff(small_mb, ix, iy, in2); + float reg_eff = (float)(im * in2) / (im + in2); + float thr_eff = get_thr_eff(small_mb, ix, iy, in2); + float cur_tot_eff = get_tot_eff( + small_mb, thr_eff, work_eff, mem_eff, reg_eff); + if (jcp.nb_oc % in2 || used_regs > free_regs + || cur_tot_eff <= tot_eff) + continue; + tot_eff = cur_tot_eff; + m_block = im; + n2_block = in2; + } + } + }; + + /* Selecting xb and yb blocking */ + int min_yb = jcp.m; + int min_xb = jcp.m; + int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2)); + int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2)); + float best_eff = 0.f; + for (int ix = min_xb; ix <= max_xb; ix += 2) { + assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2); + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2); + + int m_b[2]; + int n2_b[2]; + bool small_mb; + float inner_eff, outer_eff, work_eff; + + int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix); + work_eff = (float)jcp.oh * jcp.ow / tiled_area; + if (best_eff > 0.f && work_eff < 4.f / 9.f) + continue; // no gain from Winograd transformation + + /* outer parallelization */ + find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff); + + /* inner parallelization */ + find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff); + + small_mb = inner_eff > outer_eff; + float eff = small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.m_block = m_b[small_mb]; + jcp.n2_block = n2_b[small_mb]; + jcp.small_mb = small_mb; + } + } + } + + assert((jcp.m_block + 1) * jcp.n2_block <= free_regs); + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); + + jcp.mb_block = 1; + if (jcp.small_mb) { + // For small mb harness, set mb_block as large as possible subject to + // the constraint that winograd activations fit into available L3 cache + int L3_cap = get_cache_size(3, true); + int M = jcp.xb * jcp.yb / 4; + int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in; + int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc; + int max_mb_block = nstl::min( + jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size)); + for (int i = max_mb_block; i > 1; i--) { + if (jcp.mb % i == 0) { + jcp.mb_block = i; + break; + } + } + } + jcp.nb_mb = jcp.mb / jcp.mb_block; + + jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4; + jcp.N = jcp.oc; + jcp.K = jcp.ic; + + jcp.inp_stride = jcp.M * jcp.ic; + jcp.out_stride = jcp.M * jcp.oc; + jcp.wei_stride = jcp.ic * jcp.oc; + jcp.bia_stride = jcp.oc; + + jcp.n_block = jcp.oc_block; + jcp.k_block = jcp.ic_block; + + jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block; + + // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4 + // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is + // a multiple of load_block = 16, we just use that for now. + jcp.k2_block = load_block; + jcp.k_chunks = jcp.K / jcp.k2_block; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + memory_desc_t expect_wei_md = wei_md; + + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::s8; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format = mkldnn_wino_wei_aaOIoi; + wd.r = jcp.r; + wd.alpha = jcp.alpha; + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.ic_block; + wd.oc_block = jcp.oc_block; + wd.oc2_block = jcp.n2_block; + wd.ic2_block = 1; + wd.adj_scale = adj_wei_scale; + + size_t max_size = types::data_type_size(data_type::s8) * + jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; + max_size += types::data_type_size(data_type::s32) * + jcp.alpha * jcp.alpha * jcp.oc; + wd.size = max_size; + + if (wei_md.format_kind == format_kind::any) + wei_md = expect_wei_md; + if (wei_md != expect_wei_md) + return status::unimplemented; + + const int tilesize = jcp.alpha * jcp.alpha; + const int numtiles = jcp.M; + const int alltiles = numtiles * tilesize; + + jcp.size_wino_src + = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K) + / jcp.typesize_in; + jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic; + jcp.size_wino_dst = alltiles * jcp.oc; + + return status::success; +} +//////////////////////////////////////////////////////////////////////////////// + +template +status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + pd_t::jit_conf() { + return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf( + jcp_, *this->desc(), this->src_md_, this->weights_md_, + this->dst_md_,this->bias_md_, *this->attr()); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::pd_t:: +init_scratchpad() { + auto scratchpad = this->scratchpad_registry().registrar(); + + int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr; + scratchpad.book(key_wino_V, + sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K); + scratchpad.book(key_wino_M, + sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K); + + dim_t scale_count = attr()->output_scales_.count_; + scratchpad.book(key_conv_adjusted_scales, + sizeof(float) * nstl::max(scale_count, 16)); +} + +template +jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) +{ + kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( + pd()->jcp_, *pd()->attr()); + src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( + pd()->jcp_, *pd()->attr()); + dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( + pd()->jcp_, *pd()->attr()); +} + +template +jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: + ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() { + delete kernel_; + delete src_trans_; + delete dst_trans_; +} + +template +const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +adjust_oscales(const memory_tracking::grantor_t &scratchpad) const { + const float *oscales = pd()->attr()->output_scales_.scales_; + auto loc_scales = scratchpad.template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / (adj_src_scale * adj_wei_scale); + if (count == 1) + utils::array_set(loc_scales, oscales[0] * factor, 16); + else + for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor; + return loc_scales; +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const auto &jcp = kernel_->jcp; + if (jcp.small_mb) + execute_forward_small_mb(src, weights, bias, dst, this->scratchpad(ctx)); + else + execute_forward_mbN(src, weights, bias, dst, this->scratchpad(ctx)); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const float *oscales = adjust_oscales(scratchpad); + + auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); + auto wino_src_base = scratchpad.template get(key_wino_V); + auto wino_dst_base = scratchpad.template get(key_wino_M); + + parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb), + [&](int mb, int tile_y_b, int tile_x_b) { + + int tile_y = tile_y_b * jcp.yb; + int tile_x = tile_x_b * jcp.xb; + + int ithr = mkldnn_get_thread_num(); + auto wino_src = wino_src_base + jcp.size_wino_src * ithr; + auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + auto gemm_p = + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { + uint16_t v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min(jcp.alpha, + nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min(jcp.alpha, + nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); + v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); + } + auto local_s = src + + mb * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + } + } + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + // start threads at different GEMMs to help bring weights into LLC + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wei + jcp.wei_stride * offset; + gemm_p.dst_b = dst_bias + jcp.bia_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { + uint16_t v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); + v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); + } + auto local_d = dst + + mb * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + } + } + }); +} + +template +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: +execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const float *oscales = adjust_oscales(scratchpad); + + auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); + auto wino_src = scratchpad.template get(key_wino_V); + auto wino_dst = scratchpad.template get(key_wino_M); + + for (int mbb = 0; mbb < jcp.nb_mb; mbb++) { + for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { + for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, + [&](int y_in_block_b, int x_in_block_b, int mb) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + + uint16_t v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) + + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); + v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); + } + auto local_s = src + + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t:: + call_params_t(); + + gemm_p.src = wino_src + jcp.inp_stride * tile_ij; + gemm_p.dst = wino_dst + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, + [&](int y_in_block_b, int x_in_block_b, int mb) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + + uint16_t v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) + + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); + v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); + } + auto local_d = dst + + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} +} + +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp new file mode 100644 index 0000000000..9e6e57b051 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp @@ -0,0 +1,128 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t; +struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t; +struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t; + +template +struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8_wino:", avx512_core, ""), + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::u8, data_type::s8, + data_type::undef, dst_data_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && !has_zero_dim_memory() + && set_default_formats(); + + if (!ok) return status::unimplemented; + + status_t status = jit_conf(); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + init_scratchpad(); + + return status; + } + + jit_conv_conf_2x3_wino_t jcp_; + + protected: + status_t jit_conf(); + void init_scratchpad(); + + bool set_default_formats() { + using namespace format_tag; + return set_default_formats_common(nhwc, any, nhwc); + } + }; + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type acc_data_t; + typedef typename prec_traits::type dst_data_t; + + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd); + ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad) + const; + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t *kernel_; + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t *src_trans_; + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t *dst_trans_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp new file mode 100644 index 0000000000..f4ec29ab00 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp @@ -0,0 +1,820 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_eltwise(int position) +{ + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + + return false; +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off)); + + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + int output_offset = jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep; + + add(aux_reg_output_data, output_offset); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + L(bcast_loop_tail_out); + } +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in, + zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk, + int ur, int substep, bool wraparound) +{ + auto vreg_load = [=](int i_load) { + return Zmm(ur * load_loop_blk + i_load); + }; + + auto vreg_accum = [=](int i_load, int i_ur) { + return Zmm(i_ur * load_loop_blk + i_load); + }; + + auto zmm_bias_alpha = [=]() { + return Zmm(ur * load_loop_blk); + }; + + auto xmm_bias_alpha = [=]() { + return Xmm(ur * load_loop_blk); + }; + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_bias_data, + jcp.typesize_bia * jcp.oc_block * i_load); + }; + + auto comp_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_comp_data, + sizeof(int32_t) * jcp.oc_block * i_load); + }; + + auto scale_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_ptr_scales, + jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load)); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + + int offt = (jcp.ic_without_padding * i_ur + i_reduce); + + return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, + bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load) { + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + + int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; + + return EVEX_compress_addr(aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + + jcp.typesize_in * offt); + }; + + auto output_ptr = [=](int i_load, int i_ur) { + return EVEX_compress_addr(aux_reg_output_data, + jcp.typesize_out * (jcp.oc_without_padding * i_ur + + i_load * jcp.load_block)); + }; + + auto init = [=]() { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vpxord(r, r, r); + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)-128); + vpbroadcastb(zmm_shift, _t8); + } + }; + + auto store = [=](const bool mask_flag_in) { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); + if (p_sum_scale && *p_sum_scale != 1.f) { + mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data); + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + } + if (jcp.signed_input && jcp.ver != ver_vnni) { + mov(reg_scratch, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_scratch); + vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); + } + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1; + auto zmm_bias = zmm_tmp; + auto zmm_comp = zmm_bcast; + if (jcp.with_bias) { + if (jcp.signed_input) + mov(reg_bias_data, + EVEX_compress_addr(rsp,reg_bias_data_off)); + cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); + } + if (jcp.signed_input) { + mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); + cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vcvtdq2ps(r, r); + if (jcp.signed_input) + vaddps(r, r, zmm_comp); + if (jcp.with_bias) + vaddps(r, r, zmm_bias); + + zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r; + vmulps(mask_zmm, r, scale_ptr(i_load)); + } + } + + if (maybe_eltwise(0)) + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + if (p_sum_scale) { // post_op: sum + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && + i_load == load_loop_blk - 1; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto zmm_prev_dst = zmm_zero; + + auto r = vreg_accum(i_load, i_ur); + cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur), + mask_flag); + + if (*p_sum_scale == 1.f) + vaddps(r, zmm_prev_dst); + else + vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + + if (maybe_eltwise(1)) + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && + i_load == load_loop_blk - 1; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + if (jcp.dst_dt == data_type::u8) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + vmaxps(r, zmm_zero, r); + } + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(r, r); + } + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + zmm_t r_zmm = mask_flag ? r | ktail_mask : r; + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(output_ptr(i_load, i_ur), r_zmm); break; + case data_type::s8: + vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break; + case data_type::u8: + vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break; + default: assert(!"unknown dst_dt"); + } + } + } + mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off)); + }; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + auto fma_block = [=](bool last_block) { + int reduce_step = 4; + int tail_size = jcp.ic_without_padding % reduce_step; + int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding + ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step) + : jcp.reduce_loop_unroll; + for (int i_reduce = 0; i_reduce < loop_unroll; + i_reduce += reduce_step) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load)); + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (last_block && tail_size != 0 + && i_reduce == loop_unroll - reduce_step) { + Xmm xmm_bcast = Xmm(zmm_bcast.getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data + + jcp.ic_without_padding * i_ur + i_reduce + r], r); + vpbroadcastd(zmm_bcast, xmm_bcast); + } else { + vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false)); + } + if (jcp.signed_input) + vpsubb(zmm_bcast, zmm_bcast, zmm_shift); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + compute(vreg_accum(i_load, i_ur), + vreg_load(i_load), zmm_bcast); + } + } + } + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + if (jcp.ic != jcp.ic_without_padding) { + fma_block(true); + } else { + fma_block(false); + } + + if (jcp.oc_without_padding != jcp.oc) { + Label end_store, common_store; + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + + /*Check if it is the last load_loop_blk*/ + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + cmp(reg_load_loop_work, 0); + jg(common_store, T_NEAR); + + /*Check if it is the last ocb*/ + test(reg_reduce_pos_flag, FLAG_OC_LAST); + jz(common_store, T_NEAR); + + store(true); + jmp(end_store, T_NEAR); + + L(common_store); + store(false); + + L(end_store); + + add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + } else { + store(false); + } +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate() +{ + preamble(); + + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(zmm_one, _t); + + sub(rsp, stack_space_needed); + + if (jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + Reg32 regw_tmp = reg_last_load.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + if (jcp.with_bias) + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + if (jcp.signed_input) { + mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); + mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]); + mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); + } + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + + + auto load_loop_body = [=](int load_loop_blk) { + bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + if (jcp.with_bias) { + if (jcp.signed_input) + mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off)); + add(reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_bia); + if (jcp.signed_input) + mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); + } + if (jcp.signed_input) { + mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); + add(reg_comp_data, + load_loop_blk * jcp.load_block * sizeof(int32_t)); + mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); + } + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); + add(reg_ptr_scales, + jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float)); + mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); + mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); + add(reg_output_data, + load_loop_blk * jcp.load_block * jcp.typesize_out); + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + const int simd_w = 16; + + Label load_loop_blk[7]; + + static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; + const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast); + const int *ur_cases_fma = ur_cases_fma_expl_bcast; + const int *ur_cases = ur_cases_fma; + const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.ur <= ur_cases[ur_idx]) { + cmp(reg_load_loop_work, simd_w * (label_idx + 1)); + jle(load_loop_blk[label_idx], T_NEAR); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + if (jcp.ur <= ur_cases[ur_idx]) { + int label_idx = num_ur_cases - ur_idx - 1; + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + je(load_loop_blk[num_ur_cases], T_NEAR); + } + + for (int _i = 1; _i <= label_idx + 1; _i++) { + prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]); + prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]); + } + + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp(reg_load_loop_work, 2 * label_idx * simd_w); + je(load_loop_blk[label_idx - 1], T_NEAR); + } + cmp(reg_load_loop_work, (label_idx + 1) * simd_w); + jge(load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx > 0; --idx) { + cmp(reg_load_loop_work, simd_w * (idx + 1)); + je(load_loop_blk[idx], T_NEAR); + } + if (ur_idx < num_ur_cases - 2) { + cmp(reg_load_loop_work, simd_w); + jle(load_loop_blk[0], T_NEAR); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_eltwise(1)) + || (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( + jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + if (!mayiuse(avx512_core)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!one_of(src_d.data_type(), data_type::u8, data_type::s8) + || weights_d.data_type() != data_type::s8 + || !one_of(dst_d.data_type(), + data_type::f32, data_type::s32, data_type::s8, data_type::u8)) + return status::unimplemented; + jcp.ver = ver_avx512_core; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + jcp.tr_is = rnd_up(jcp.is, 4); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + format_tag_t dat_tag = format_tag::nhwc; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + const int simd_w = 16; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + args_ok = true + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + const int SMALL_SPATIAL = 7 * 7; + const int BIG_REDUCE_DIM = 1024; + + int load_blocking = 0; + int load_blocking_max = 0; + int bcast_blocking = 0; + int bcast_blocking_max = 0; + int reduce_blocking = 0; + int reduce_blocking_max = 0; + jcp.load_grp_count = 1; + jcp.use_vmovntps = false; + + const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in); + const int L2_capacity = (L2_size * 3) / 4; + + int size_treshold = 28; + int max_regs = 0; + int min_regs = 6; + if (jcp.ver == ver_vnni) + max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold) + && (jcp.oc < 128 || jcp.ic < 128)) ? min_regs : 9; + else + max_regs = 8; + jcp.expl_bcast = true; + + if (jcp.mb == 1 && jcp.ic > 128 + && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) { + if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size) + max_regs = min_regs; // mobilenet_v2 performance improvement + jcp.ur = nstl::min(max_regs, jcp.os); + } else { + const int spatial = jcp.oh; + jcp.ur = 1; + for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { + if ((spatial >= size_treshold && spatial % ur_w == 0) + || (spatial < size_treshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + if (jcp.ur == 1) { + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i--) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) + break; + } + } + } + } + + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.typesize_in; + + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out; + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step + = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; + + jcp.load_loop_iter_step = jcp.load_block; + + jcp.loop_order = reduce_src ? loop_blr : loop_lbr; + + int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + reduce_blocking = nb_reduce; + if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 64; + else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 16; + reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + + bool cmp_reduce = reduce_blocking <= jcp.reduce_dim; + if (cmp_reduce) + jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; + load_blocking = jcp.load_dim; + + jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider( + nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); + + if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); // + load_blocking = jcp.load_block; + } + + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast + = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking + - jcp.ur * reduce_blocking - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) + space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + assert(reduce_blocking_max); + assert(load_blocking % jcp.load_block == 0); + assert(reduce_blocking % jcp.reduce_block == 0); + assert(load_blocking_max % jcp.load_block == 0); + assert(reduce_blocking_max % jcp.reduce_block == 0); + + assert(jcp.reduce_loop_unroll % 4 == 0); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + + assert(jcp.bcast_block % jcp.ur == 0); + assert(jcp.reduce_dim % jcp.reduce_block == 0); + + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + // miniumum size of load dim chunk for work distribution within threads + jcp.nb_load_chunk = 1; + // peformance improvements for googlenet_v3, mb=1; + // TODO: generalize this condition and rewrite it in appropriate manner + if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4 + && jcp.ic * jcp.oc <= L2_size) { + jcp.nb_load_chunk = 4; + jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count); + } + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + return status::success; +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, 16); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp new file mode 100644 index 0000000000..22e9732a1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP +#define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t) + jit_avx512_core_x8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), + eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); + } + + ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const memory_desc_wrapper &bias_d, + const primitive_attr_t &attr, + int nthreads, bool reduce_src); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr); + + bool maybe_eltwise(int position); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + + private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + using mask_t = const Xbyak::Opmask; + + reg64_t reg_bcast_data = r8; + reg64_t reg_ptr_scales = r8; + reg64_t reg_output_data = r9; + reg64_t reg_load_data = r10; + reg64_t reg_ptr_sum_scale = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t reg_bias_data = r12; + reg64_t reg_comp_data = r12; + reg64_t reg_scratch = r13; + reg64_t aux_reg_bcast_data = r14; + reg64_t aux_reg_load_data = r15; + reg64_t imm_addr64 = r15; + reg64_t reg_reduce_pos_flag = rax; + reg64_t aux1_reg_bcast_data = rbx; + reg64_t reg_bcast_loop_work = rbx; + reg64_t bcast_loop_iter = rdx; // Note: Fix me + reg64_t reg_load_loop_work = rsi; + reg64_t aux_reg_output_data = abi_not_param1; + reg64_t reduce_loop_iter = abi_param1; + + reg64_t reg_last_load = r8; + mask_t ktail_mask = k6; + + mask_t vmask = k7; + + Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28); + Xbyak::Zmm zmm_one = Xbyak::Zmm(29); + Xbyak::Zmm zmm_zero = Xbyak::Zmm(30); + Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31); + Xbyak::Zmm zmm_shift = Xbyak::Zmm(30); + + Xbyak::Zmm zmm_bias_alpha = Xbyak::Zmm(31); + Xbyak::Xmm xmm_bias_alpha = Xbyak::Xmm(31); + + int bcast_loop_work_off = 0; + int reg_bias_data_off = 8; + int reg_bcast_data_off = 16; + int reg_load_data_off = 24; + int reg_ptr_sum_scale_off = 32; + int reg_comp_data_off = 40; + int stack_space_needed = 48; + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp new file mode 100644 index 0000000000..0bf09fc677 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp @@ -0,0 +1,292 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +namespace { +template +void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, + T nx, T &nx_start, T &nx_end, T nx_divider) +{ + const T grp_size = utils::div_up(nthr, nx_divider); + const T grp_count = utils::div_up(nthr, grp_size); + + T grp = ithr / grp_size; + T grp_ithr = ithr % grp_size; + T grp_nthr = grp_size; + T first_grps = nthr % grp_count; + if (first_grps > 0 && grp >= first_grps) { + ithr -= first_grps * grp_size; + grp_nthr--; + grp = ithr / grp_nthr + first_grps; + grp_ithr = ithr % grp_nthr; + } + balance211(nx, grp_count, grp, nx_start, nx_end); + balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); +} +} + +/* convolution forward */ +template +void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) { + auto local_scales = scratchpad.template get( + key_conv_adjusted_scales); + auto scales = pd()->attr()->output_scales_.scales_; + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, scales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = scales[c] * factor; + } + } + + parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); + }); +} + +template +void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t +::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, + const wei_data_t *weights, const char *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get(key_conv_rtus_space); + auto local_scales = scratchpad.get(key_conv_adjusted_scales); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + const int stride_h = pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[1]; + const int pad_t = pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][1]; + + const auto &oscales = pd()->attr()->output_scales_; + + int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block) + * jcp.oc_block * jcp.ic_block; + wei_data_t *w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(w + offset) : 0; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + + auto rp = rtus_driver_t::call_params_t(); + const int nb_oc = jcp.nb_load; + const int os_block = jcp.bcast_block; + + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end, + jcp.load_grp_count); + if (jcp.nb_load_chunk > 1) { + ocb_start *= jcp.nb_load_chunk; + ocb_end *= jcp.nb_load_chunk; + } + + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + int &oh, int &ow, int &ih, int &iw) + { + int osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + + if (ocb + load_step >= nb_oc) + p.first_last_flag |= FLAG_OC_LAST; + else + p.first_last_flag &= ~FLAG_OC_LAST; + + }; + + auto init_reduce = [&]() + { + p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int n, int g, int oh, int ow, + int ih, int iw) + { + const int icb = 0; // Start from the first IC block + const int _ocb = g * nb_oc + ocb; + const int _icb = g; + + const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow); + + p.output_data = &dst[dst_off]; + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size]; + p.compensation = (jcp.signed_input) + ? &compensation[_ocb * jcp.oc_block] : 0; + p.scales = (jcp.signed_input && jcp.ver != ver_vnni) + ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block] + : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block]; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); + rtus_driver_->ker_(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + init_reduce(); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + inner_ker(ocb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + init_reduce(); + inner_ker(ocb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + init_reduce(); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + inner_ker(ocb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + init_reduce(); + inner_ker(ocb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } +} + +using namespace data_type; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp new file mode 100644 index 0000000000..ad9027ac17 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp @@ -0,0 +1,159 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" +#include "jit_uni_1x1_conv_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8_1x1:", avx512_core, ""), + jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t< + src_type, dst_type>); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, data_type::s8, data_type::undef, + dst_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), format_tag::any, + dat_tag()) + && set_or_check_wei_format(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx512_core_x8s8s32x_1x1_conv_kernel:: + init_conf(jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), + *weights_md(1), *attr(), mkldnn_get_max_threads(), + rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( + scratchpad, jcp_, *attr()); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + bool set_or_check_wei_format() { + using namespace format_tag; + + const bool is_src_s8 = src_md_.data_type == data_type::s8; + format_tag_t wei_tag = with_groups() ? gOIhw4i16o4i : OIhw4i16o4i; + + memory_desc_t want_wei_md = weights_md_; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (is_src_s8) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups() ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md_.format_kind == format_kind::any) { + weights_md_ = want_wei_md; + return true; + } + + return weights_md_ == want_wei_md; + } + }; + + template + friend void init_rtus_driver(conv_t *self); + + jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx512_core_x8s8s32x_1x1_conv_kernel(pd()->jcp_, + *pd()->attr()); + init_rtus_driver(this); + } + + ~jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + + private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const char *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_x8s8s32x_1x1_conv_kernel *kernel_; + rtus_driver_t *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp new file mode 100644 index 0000000000..e89d068302 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp @@ -0,0 +1,140 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_deconvolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t + : public cpu_primitive_t { + struct pd_t : public cpu_deconvolution_fwd_pd_t { + pd_t(engine_t *engine, const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) {} + + pd_t(const pd_t &other) + : cpu_deconvolution_fwd_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), + jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t); + + status_t init_convolution() { + convolution_desc_t cd; + status_t status; + + auto dd = desc(); + status = conv_desc_init(&cd, prop_kind::forward_training, + alg_kind::convolution_direct, &(dd->src_desc), + &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc), + dd->strides, dd->dilates, dd->padding[0], dd->padding[1], + dd->padding_kind); + + if (status == status::success) { + status = mkldnn_primitive_desc::create( + &conv_pd_, (op_desc_t *)&cd, &attr_, engine_, nullptr); + } + + if (status == status::success) + status = set_default_params(); + + return status; + }; + + status_t init() { + bool ok = true + && is_fwd() + && desc()->alg_kind == alg_kind::deconvolution_direct + && !has_zero_dim_memory() + && desc()->src_desc.data_type == src_type + && desc()->dst_desc.data_type == dst_type + && desc()->weights_desc.data_type == data_type::s8 + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && desc()->accum_data_type == data_type::s32; + if (!ok) return status::unimplemented; + + CHECK(init_convolution()); + + return status::success; + } + + virtual void init_scratchpad_md() override { + const auto conv_1x1_pd = static_cast(conv_pd_); + scratchpad_md_ = *conv_1x1_pd->scratchpad_md(); + } + + protected: + status_t set_default_params() { + auto conv_1x1_pd_ = static_cast(conv_pd_); + src_md_ = *conv_1x1_pd_->src_md(); + dst_md_ = *conv_1x1_pd_->dst_md(); + weights_md_ = *conv_1x1_pd_->weights_md(); + if (with_bias()) + bias_md_ = *conv_1x1_pd_->weights_md(1); + return status::success; + } + + using conv_pd_t = typename jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t + ::pd_t; + friend jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t; + primitive_desc_t *conv_pd_; + }; + + jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + + ~jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t() + { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + return conv_p_->execute(ctx); + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +} +} +} + +#endif /* CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp new file mode 100644 index 0000000000..10e98a00c4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp @@ -0,0 +1,1182 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { +void pick_loop_order(jit_conv_conf_t &jcp, int nthr) +{ + jcp.loop_order = loop_cwgn; + if (jcp.ngroups > 1) { + jcp.loop_order = loop_ngcw; + if (jcp.mb < nthr) + jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg; + } +} +} + +template +bool _jit_avx512_core_x8s8s32x_fwd_kernel::maybe_eltwise(int position) +{ + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + + return false; +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::prepare_output(int ur_w) +{ + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + for (int k = 0; k < nb_oc_block; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vpxord(vmm, vmm, vmm); + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + if (jcp.is_depthwise && !jcp.is_fast_depthwise) { + Reg32 _t32 = reg_scratch.cvt32(); + mov(_t32, (uint32_t)128); + vpbroadcastd(vmm_shift, _t32); + } else { + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)128); + vpbroadcastb(vmm_shift, _t8); + } + } +} + +template +const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel:: + vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) { + return vmm_in; +} + +template<> +const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel:: + vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) { + return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) + : zmm_in; +} + + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::cvt2ps(data_type_t type_in, + const Vmm vmm_in, const Operand &op, bool mask_flag) { + //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in; + const Vmm vmm = vmm_mask(vmm_in, mask_flag); + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(vmm, op); break; + case data_type::s8: vpmovsxbd(vmm, op); break; + case data_type::u8: vpmovzxbd(vmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(vmm_in, vmm_in); +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_eltwise(int ur_w) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + if (ur_w == jcp.ur_w) + eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w); + else + for (int k = 0; k < nb_oc_block; k++) + eltwise_injector_->compute_vector_range(k * jcp.ur_w, + k * jcp.ur_w + ur_w); +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::store_output( + int ur_w, bool last_oc_block_flag) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block; + + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + if (jcp.signed_input) + mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); + + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = nullptr; + if (sum_idx != -1) { + const auto &p_entry = p.entry_[sum_idx]; + p_sum_scale = &p_entry.sum.scale; + } + + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + if (jcp.signed_input && jcp.ver != ver_vnni) { + /* put 'wei_adj_scale = 0.5' for bias calculation */ + mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_bias_alpha); + vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha()); + } + + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block); + if (jcp.with_bias) { + int bias_offset = jcp.typesize_bia * k * oc_block; + auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); + + cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + /* bias *= 0.5 */ + vmulps(vmm_bias, vmm_bias, vmm_bias_alpha()); + } + if (jcp.signed_input) { + int comp_offset = sizeof(int32_t) * k * oc_block; + auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); + + cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag); + } + /* add to zmm_accum: compensation, bias and permute */ + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + if (jcp.is_fast_depthwise) + vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k)); + vcvtdq2ps(vmm, vmm); + if (jcp.signed_input) + vaddps(vmm, vmm, vmm_comp); + if (jcp.with_bias) + vaddps(vmm, vmm, vmm_bias); + + const Vmm vmm_k = vmm_mask(vmm, mask_flag); + vmulps(vmm_k, vmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); + } + } + + /* Do post-ops */ + if (maybe_eltwise(0)) compute_eltwise(ur_w); + if (p_sum_scale) { // post_op: sum + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + for (int j = 0; j < ur_w; j++) { + int aux_output_offset + = jcp.typesize_out + * (k * oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + Vmm vmm = vmm_out(j, k); + cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) + vaddps(vmm, vmm_prev_dst); + else + vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + if (maybe_eltwise(1)) compute_eltwise(ur_w); + + /* write out register to output_addr */ + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + if (jcp.dst_dt == data_type::u8) { + vpxord(vmm_zero, vmm_zero, vmm_zero); + vmaxps(vmm, vmm_zero, vmm); + } + + if (jcp.dst_dt != data_type::f32) { + /* Note: using Zmm for rounding in Xmm/Ymm kernel + because there is no instruction to do rounding + from Xmm/Ymm -> Xmm/Ymm. + Embedded rounding is not supported for Xmm. + TODO: maybe avoid Zmm if it helps performance.*/ + Zmm zmm = zmm_out(j, k); + vcvtps2dq(zmm, zmm); + } + } + + for (int j = 0; j < ur_w; j++) { + int aux_output_offset = jcp.typesize_out + * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + + Vmm vmm = vmm_out(j, k); + const Vmm r_vmm = vmm_mask(vmm, mask_flag, true); + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: vmovups(addr, r_vmm); break; + case data_type::s8: vpmovsdb(addr, r_vmm); break; + case data_type::u8: vpmovusdb(addr, r_vmm); break; + default: assert(!"unknown dst_dt"); + } + } + } + +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + assert(!"invalid group blocking for depthwise convolution"); +} + +template <> +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + + auto input_spatial_index = [=](int oi, int ki) { + return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l); + }; + + auto input_offset2 = [=](int ii, int ci) { + return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block); + }; + + auto input_offset3 = [=](int oi, int ci, int ki) { + return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci); + }; + + auto kernel_offset = [=](int ci, int ki) { + return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block); + }; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + // okay for depthwise since src is zero-extended + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddwd(zmm_tmp, vreg_src, vreg_wei); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + int ii_start = 0; + int ii_end = -1; + if (jcp.is_resrc_depthwise && !h_padded) { + // find bounds of input spatial indices + bool first = true; + for (int ki = 0; ki < jcp.kw; ki++) { + int oi_start = get_ow_start(ki, pad_l); + int oi_end = get_ow_end(ur_w, ki, pad_r); + for (int oi = oi_start; oi < oi_end; oi++) { + int ii = input_spatial_index(oi, ki); + if (first || ii < ii_start) + ii_start = ii; + if (first || ii > ii_end) + ii_end = ii; + first = false; + } + } + } + + if (jcp.signed_input) { + vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero); + vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift); + } + for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) { + const bool mask_flag = last_ic_block_flag != no_last_block + && ci == jcp.nb_ch_blocking - 1; + if (jcp.is_resrc_depthwise && !h_padded) { + // now we can load input once and reuse up to jcp.kw times + for (int ii = ii_start; ii <= ii_end; ii++) { + int aux_input_offset = input_offset2(ii, ci); + const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking); + const Zmm zmm_inp_msk = mask_flag + ? zmm_inp_tmp | ktail_mask | T_z + : zmm_inp_tmp; + if (jcp.is_fast_depthwise) { + assert(!mask_flag); + vbroadcasti32x4(zmm_inp_msk, + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + } else { + vpmovzxbd(zmm_inp_msk, + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift); + } + } + for (int ki = 0; ki < jcp.kw; ki++) { + int aux_kernel_offset = kernel_offset(ci, ki); + if (jcp.is_fast_depthwise) { + vbroadcasti32x4(zmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei); + } else { + vpmovsxbd(zmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + if (h_padded) { + assert(jcp.signed_input); + for (int oi = 0; oi < ur_w; oi++) + compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero); + } else { + const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src; + int oi_start = get_ow_start(ki, pad_l); + int oi_end = get_ow_end(ur_w, ki, pad_r); + int start_ = jcp.signed_input ? 0 : oi_start; + int end_ = jcp.signed_input ? ur_w : oi_end; + for (int oi = start_; oi < end_; oi++) { + if (oi >= oi_start && oi < oi_end) { + if (jcp.is_resrc_depthwise) { + int ii = input_spatial_index(oi, ki); + zmm_src = zmm_inp(ii, jcp.nb_ch_blocking); + } else { + int aux_input_offset = input_offset3(oi, ci, ki); + if (jcp.is_fast_depthwise) { + assert(!mask_flag); + vbroadcasti32x4(r_zmm_src, + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + } else { + vpmovzxbd(r_zmm_src, + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(zmm_src, zmm_src, vmm_shift); + } + } else if (jcp.signed_input) { + zmm_src = zmm_shifted_zero; + } + compute(zmm_out(oi, ci), zmm_wei, zmm_src); + } + } + } + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, + int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + if (jcp.is_depthwise) + return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded); + + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int ch_block_all = jcp.ch_block * ic_block * oc_block; + + int nb_oc_block = jcp.nb_oc_blocking; + + auto input_offset = [=](int oi, int ic, int ki) { + return jcp.typesize_in + * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * jcp.ic_without_padding * jcp.ngroups + 4 * ic); + }; + auto kernel_offset = [=](int ii, int ic, int ki) { + return jcp.typesize_in + * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all + + 4 * ic * oc_block); + }; + auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(vmm_tmp, vreg_src, vreg_wei); + vpmaddwd(vmm_tmp, vmm_tmp, vmm_one); + vpaddd(vreg_acc, vreg_acc, vmm_tmp); + } + }; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_ow_start(ki, pad_l); + int jj_end = get_ow_end(ur_w, ki, pad_r); + int tail_size = jcp.ic_without_padding % 4; + int _start = (jcp.signed_input) ? 0 : jj_start; + int _end = (jcp.signed_input) ? ur_w : jj_end; + /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */ + int icb = (last_ic_block_flag != no_last_block) + ? div_up((jcp.ic_without_padding % ic_block), 4) + : ic_block / 4; + for (int ic = 0; ic < icb; ic++) { + if (h_padded == true) { + /* fill padded area with shifted values */ + Vmm inp = vmm_inp(0,nb_oc_block); + vpxord(inp, inp, inp); + vpaddb(inp, inp, vmm_shift); + } else { + for (int jj = _start; jj < _end; jj++) { + int aux_input_offset = input_offset(jj, ic, ki); + if (jj >= jj_start && jj < jj_end) { + if (last_ic_block_flag == last_sp_block + && tail_size != 0 && ic == icb - 1) { + Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_inp + aux_input_offset + r], r); + vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp); + } else { + vpbroadcastd(vmm_inp(jj, nb_oc_block), + EVEX_compress_addr( + aux_reg_inp, aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(vmm_inp(jj, nb_oc_block), + vmm_inp(jj, nb_oc_block), vmm_shift); + } else { + /* fill padded area with shifted values */ + if (jcp.signed_input) { + Vmm inp = vmm_inp(jj, nb_oc_block); + vpxord(inp, inp, inp); + vpaddb(inp, inp, vmm_shift); + } + } + } + } + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, ic, ki); + vmovups(vmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + for (int jj = _start; jj < _end; jj++) { + Vmm inp = (h_padded == true) + ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block); + compute(vmm_out(jj, ii), vmm_wei, inp); + } + } + } + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::kh_loop( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) { + Label kh_label, skip_kh_loop; + Label t_overflow_label, no_t_overflow_label, + b_overflow_label, no_b_overflow_label; + + int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ic_without_padding * jcp.ngroups; + + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + cmp(reg_overflow, 0); + je(no_t_overflow_label, T_NEAR); + L(t_overflow_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); + + add(aux_reg_ker, shift_kernel_ptr); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(t_overflow_label, T_NEAR); + } + L(no_t_overflow_label); + } + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + if ((jcp.signed_input) || (!jcp.signed_input && + (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) { + cmp(reg_kj, 0); + je(skip_kh_loop, T_NEAR); + } + L(kh_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false); + + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + L(skip_kh_loop); + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + cmp(reg_overflow, 0); + je(no_b_overflow_label, T_NEAR); + L(b_overflow_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); + + add(aux_reg_ker, shift_kernel_ptr); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(b_overflow_label, T_NEAR); + } + L(no_b_overflow_label); + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::icb_loop( + int ur_w, int pad_l, int pad_r, bool is_last_sp_block) +{ + prepare_output(ur_w); + + // IC loop + Label icb_label; + mov(reg_icb, jcp.nb_ic); + L(icb_label); + if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + + cmp(reg_icb, 1); // The last IC block + jne(common_ker, T_NEAR); + + kh_loop(ur_w, pad_l, pad_r, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + kh_loop(ur_w, pad_l, pad_r, no_last_block); + + L(end_ker); + } else { + kh_loop(ur_w, pad_l, pad_r, no_last_block); + } + // End of IC Loop + int inp_step = jcp.ic_block; + int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block; + add(reg_inp, jcp.typesize_in * inp_step); + add(reg_ker, jcp.typesize_in * ker_step); + + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_label, T_NEAR); + + sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic); + sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + + jne(common_store, T_NEAR); + + store_output(ur_w, true); // last oc block + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, false); + + L(end_store); + } else { + store_output(ur_w, false); + } +} + +template +void _jit_avx512_core_x8s8s32x_fwd_kernel::generate() +{ + Label permute_index_table; + int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad) + * jcp.ic_without_padding * jcp.ngroups; + int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad + * jcp.ic_without_padding * jcp.ngroups; + int inp_shift = jcp.typesize_in * + (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding + * jcp.ngroups); + int out_shift = jcp.typesize_out * + (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); + preamble(); + + if (jcp.is_depthwise) { + int idx = jcp.max_regs_ur - 1; + if (!jcp.is_resrc_depthwise) + zmm_src = Zmm(++idx); + if (jcp.ver != ver_vnni) + zmm_tmp = Zmm(++idx); + if (jcp.is_fast_depthwise) + zmm_permute = Zmm(++idx); + if (jcp.signed_input) { + zmm_shifted_zero = Zmm(++idx); + ++idx; // due to extra register used for shifts and compensations + } + assert(idx == ker_dw_reg_base_idx); + } + + if (!jcp.is_depthwise && jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t16 = reg_scratch.cvt16(); + mov(_t16, 0x1); + vpbroadcastw(vmm_one, _t16); + } + + mov(reg_inp, ptr[param1 + GET_OFF(src)]); + mov(reg_out, ptr[param1 + GET_OFF(dst)]); + mov(reg_ker, ptr[param1 + GET_OFF(filt)]); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise + ? jcp.ngroups % jcp.ch_block + : jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + Reg32 regw_tmp = reg_oi.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + if (jcp.is_fast_depthwise) { + // prepare mask register for blending weights + mov(reg_scratch, 0x8888444422221111); + kmovq(kblend_mask, reg_scratch); + // load permute indices from data section + mov(reg_scratch, permute_index_table); + vmovdqu32(zmm_permute, ptr[reg_scratch]); + } + + int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + int n_oi = jcp.ow / jcp.ur_w; + int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); + + if (jcp.nb_ow == 1) { + if (r_pad1 > 0 || jcp.ur_w_tail == 0) + n_oi--; + + xor_(reg_oi, reg_oi); + if (jcp.ow == jcp.ur_w) { + icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true); + } else { + if (n_oi == 0) { + icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + } else { + if (jcp.l_pad > 0) { + icb_loop(jcp.ur_w, jcp.l_pad, 0, false); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + + inc(reg_oi); + } + if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)) + { + Label ow_loop_label; + L(ow_loop_label); { + icb_loop(jcp.ur_w, 0, 0, false); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_pad1 > 0 || jcp.ur_w_tail == 0) { + icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + } + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + } + } + } else { + // ow block is only processed. + // Number of block is passed as parameter owb, + // and padding processing depends on this number. + Label end_label, last_oi_label, middle_ow_blocks_label, tail_label, + oi_loop_label, oi_loop_end_label; + + assert(jcp.ow_block % jcp.ur_w == 0); + int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w; + // to simplify code (and general regs usage), + // size of ow block must be >= 2 * ur_w + assert(n_oi_not_last_ow_block > 1); + int n_oi_next_last_ow_block = n_oi_not_last_ow_block; + int n_oi_first_ow_block = n_oi_not_last_ow_block; + int n_oi_last_ow_block + = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w; + // prepare right padding + bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; + bool first_ow_block_padded + = next_last_ow_block_padded && jcp.nb_ow == 2; + bool last_ow_block_padded + = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0; + + if (last_ow_block_padded) n_oi_last_ow_block--; + else if (first_ow_block_padded) n_oi_first_ow_block--; + else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // is that the first ow-block ? + jg(middle_ow_blocks_label, T_NEAR); + + // the first ow block, compute left padding + mov(reg_oi, n_oi_first_ow_block); + if (jcp.l_pad > 0) { + icb_loop(jcp.ur_w, jcp.l_pad, 0, false); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + + dec(reg_oi); + } + jmp(oi_loop_label, T_NEAR); + + // middle or last ow block entry + L(middle_ow_blocks_label); + + if (jcp.l_pad > 0) { + // just to consider left padding, not compute + add(reg_inp, inp_shift_pad_second_block); + } + + // set number of iteration for oi-loop + if (n_oi_last_ow_block != n_oi_not_last_ow_block) { + cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? + mov(reg_oi, n_oi_last_ow_block); + je(oi_loop_label, T_NEAR); + } + + if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) { + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + + mov(reg_oi, n_oi_next_last_ow_block); + je(oi_loop_label, T_NEAR); + } + mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks + + // oi loop w/o padding + L(oi_loop_label); { + cmp(reg_oi, 0); + jle(oi_loop_end_label, T_NEAR); + + icb_loop(jcp.ur_w, 0, 0, false); + + add(reg_inp, inp_shift); + add(reg_out, out_shift); + dec(reg_oi); + + jmp(oi_loop_label, T_NEAR); + } + L(oi_loop_end_label); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // first ow-block ? + if (first_ow_block_padded) + je(last_oi_label, T_NEAR); + else + je(end_label, T_NEAR); + + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + jl(end_label, T_NEAR); + if (next_last_ow_block_padded) + je(last_oi_label, T_NEAR); + else + je(end_label, T_NEAR); + + // that is last block + if (!last_ow_block_padded) + jmp(tail_label, T_NEAR); + + // last oi block with right padding + L(last_oi_label); + icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? + jl(end_label, T_NEAR); + + // ur_w tail + L(tail_label); + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + L(end_label); + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); + + if (jcp.is_fast_depthwise) { + align(64); + L(permute_index_table); + const uint32_t _idx[] + = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i) + dd(_idx[i]); + } +} + +bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) +{ + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_eltwise(1)) || + (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr, + int nthreads) +{ + using namespace prop_kind; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + bool is_1d = ndims == 3; + + if (!(mayiuse(avx512_core) + && one_of(src_d.data_type(), data_type::u8, data_type::s8) + && weights_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + + jcp = zero(); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; + jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc); + + if (jcp.is_depthwise) { + jcp.ch_block = 16; + jcp.ic_block = 1; + jcp.oc_block = 1; + } else { + jcp.ch_block = 1; + jcp.ic_block = 16; + jcp.oc_block = 16; + + if (jcp.ngroups == 1) { + /* For non grouped convolutions, pad channels by 16 if needed */ + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) { + /* For grouped convolutions, MKL-DNN doesn't support padding. + Use Ymm when channels per group is multiple of 8, + Xmm when channels per group is multiple of 4 */ + jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4; + jcp.oc_block = jcp.ic_block; + } + if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0) + return status::unimplemented; + } + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core; + jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni + && jcp.ngroups % jcp.ch_block == 0; // for groups not multiple of 16 + // would require byte masking + // for load from src + jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw + && jcp.kw < 4 && jcp.dilate_w == 0; + if (jcp.is_depthwise) { + jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise + - 2 * jcp.signed_input - (jcp.ver != ver_vnni); + } else { + jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28; + } + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + format_tag_t wei_tag; + if (jcp.ic_block == 16 || jcp.ch_block == 16) { + if (is_1d) { + wei_tag = with_groups + ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i + : OIw4i16o4i; + } else { + wei_tag = with_groups + ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i + : OIhw4i16o4i; + } + } else if (with_groups && jcp.ic_block == 8) { + wei_tag = gOIhw2i8o4i; + } else + wei_tag = gOIhw4o4i; + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (jcp.signed_input) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + format_tag_t dat_tag = utils::pick(ndims - 3, + format_tag::nwc, format_tag::nhwc); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + // Try to use 4 channel-groups at a time to avoid false sharing (depthwise) + int nb_ch_blocking = 4; + for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--) + if (jcp.nb_ch % nb_ch_blocking == 0) + break; + jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1; + + // If OC blocking is incommensurate with the number of OC blocks (general + // requirement for all convolutions), or if it results in an unrolling + // factor smaller than the left padding (special requirement for SSD:fc6), + // then search for a smaller OC blocking that satisfies both constraints. + auto is_oc_blocking_ok = [&](int block) { + int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1)); + return jcp.nb_oc % block == 0 + && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1; + }; + + // choose nb_oc work chunk size for distribution within threads + int max_threading_nb_oc_chunk = 4; + // Performance improvements for googlenet_v3 and resnet_50 with mb = 1; + // TODO: generalize this condition and rewrite it in appropriate manner + if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3 + && jcp.stride_w == 1 && jcp.ic % 64 == 0) + max_threading_nb_oc_chunk = 2; + jcp.nb_oc_blocking_thr_chunk = + nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc); + for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) { + if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) + break; + } + + // choose oc blocking for computational kernel + jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk; + // Performance improvements for googlenet_v3 with mb = 1; + // TODO: generalize this condition and rewrite it in appropriate manner + const int size_treshold_for_nb_oc_blocking_reduction = 17; + if (jcp.mb == 1 && jcp.ow <= size_treshold_for_nb_oc_blocking_reduction + && jcp.stride_w == 1 + && !(jcp.kh == 1 && jcp.kw == 3) + && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) { + const int max_nb_oc_blocking = 2; + jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc); + for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) + if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0 + && is_oc_blocking_ok(jcp.nb_oc_blocking)) + break; + } + + if (jcp.is_resrc_depthwise) + jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w) + / (jcp.nb_ch_blocking + jcp.stride_w); + else + jcp.ur_w + = jcp.max_regs_ur / (jcp.is_depthwise ? jcp.nb_ch_blocking + : jcp.nb_oc_blocking + 1); + if (jcp.ow < jcp.ur_w) + jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + jcp.ow_block = jcp.ow; + int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh + * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk); + float best_thr_eff + = (float)base_work_amount / rnd_up(base_work_amount, nthreads); + int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w); + for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) { + int ow_block + = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow); + if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block + && best_thr_eff > 0.8f) + break; + if (div_up(jcp.ow, ow_block) != nb_ow) + continue; + auto work_amount = base_work_amount * nb_ow; + float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads); + if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) { + jcp.ow_block = ow_block; + best_thr_eff = thr_eff; + } + if (best_thr_eff > 0.9f) + break; + } + jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); + + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0); + if (!args_ok) + return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + if (r_pad_no_tail > jcp.ur_w) + return status::unimplemented; + + pick_loop_order(jcp, nthreads); + + jcp.nb_ic_L2 = jcp.nb_ic; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + return status::success; +} + +void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, + const primitive_attr_t &attr) { + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, (dim_t)jcp.ic_block); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +template struct _jit_avx512_core_x8s8s32x_fwd_kernel; +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp new file mode 100644 index 0000000000..d8a05ad53e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp @@ -0,0 +1,239 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t) + + enum { STATE_FIRST_DST_LOAD = 0x1U }; + + _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), + eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + + generate(); + jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); + } + + ~_jit_avx512_core_x8s8s32x_fwd_kernel() { + delete eltwise_injector_; + } + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker_)(jit_conv_call_s *); + +private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + ker_dw_reg_base_idx = 30, + }; + typedef enum { + no_last_block, + last_ic_block, + last_sp_block, + } ic_block_t; + + /* data regs */ + const Xbyak::Reg64 reg_ptr_scales = rax; + const Xbyak::Reg64 reg_inp = r8; + const Xbyak::Reg64 reg_ker = r9; + const Xbyak::Reg64 reg_out = r10; + const Xbyak::Reg64 aux_reg_inp = r11; + const Xbyak::Reg64 reg_ptr_sum_scale = r11; + const Xbyak::Reg64 aux_reg_ker = r12; + const Xbyak::Reg64 reg_compensation = r14; + /* counter regs */ + const Xbyak::Reg64 reg_bias_alpha = abi_not_param1; + const Xbyak::Reg64 reg_oi = rbx; + const Xbyak::Reg64 reg_bias = rdx; + const Xbyak::Reg64 reg_oc_blocks = rsi; + const Xbyak::Reg64 reg_owb = aux_reg_ker; + const Xbyak::Reg64 reg_scratch = reg_compensation; + const Xbyak::Reg64 reg_kj = reg_ptr_scales; + const Xbyak::Reg64 reg_overflow = reg_ptr_scales; + const Xbyak::Reg64 reg_icb = reg_bias; + + const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3); + + const Vmm vmm_wei = Vmm(31); + /* used during bias section of store_output */ + const Vmm vmm_comp = Vmm(30); // only for signed input + const Vmm vmm_bias = Vmm(31); + /* used during post_op sum section of store_output */ + const Vmm vmm_prev_dst = Vmm(31); + /* used during write-out section of store_output */ + const Vmm vmm_zero = Vmm(31); + + /* used in compute_ker (but set during prepare_output) */ + const Vmm vmm_shift = vmm_comp; // only for signed input + /* used in compute_ker (but only for pre-VNNI machines) */ + const Vmm vmm_tmp = Vmm(28); // not used for depthwise + const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise. + + /* registers use only for depthwise + groups are always blocked by 16(padded if needed), + hence use only Zmm registers */ + const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); + Xbyak::Zmm zmm_tmp; + Xbyak::Zmm zmm_src; + Xbyak::Zmm zmm_shifted_zero; + Xbyak::Zmm zmm_permute; + + Vmm vmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < (jcp.is_depthwise + ? ker_dw_reg_base_idx : ker_reg_base_idx)); + return Vmm(idx); + } + Xbyak::Zmm zmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < (jcp.is_depthwise + ? ker_dw_reg_base_idx : ker_reg_base_idx)); + return Xbyak::Zmm(idx); + } + Vmm vmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Vmm(idx); + } + Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Xbyak::Zmm(idx); + } + Vmm vmm_bias_alpha() { + int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + return Vmm(nb_c_block * jcp.ur_w); + } + Xbyak::Xmm xmm_bias_alpha() { + int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + return Xbyak::Xmm(nb_c_block * jcp.ur_w); + } + int get_ow_start(int ki, int pad_l) { + return nstl::max(0, + utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); + } + int get_ow_end(int ur_w, int ki, int pad_r) { + return ur_w - nstl::max(0, utils::div_up(pad_r + - (jcp.kw - 1 - ki) + * (jcp.dilate_w + 1), + jcp.stride_w)); + } + + bool maybe_eltwise(int position); + void prepare_output(int ur_w); + void store_output(int ur_w, bool last_oc_block_flag); + void compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded); + void compute_ker(int ur_w, int pad_l, int pad_r, + ic_block_t last_ic_block_flag, bool h_padded = false); + void compute_eltwise(int ur_w); + void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag); + void icb_loop( + int ur_w, int pad_l, int pad_r, bool is_last_spatial_block); + void generate(); + void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op, + bool mask_flag); + const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false); +}; + +struct jit_avx512_core_x8s8s32x_fwd_kernel { + + jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : + jit_ker(nullptr), + zmm_kernel_(nullptr), + ymm_kernel_(nullptr), + xmm_kernel_(nullptr) { + int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block; + switch (ch_block) { + case 16: + zmm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = zmm_kernel_->jit_ker_; + return; + case 8: + ymm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = ymm_kernel_->jit_ker_; + return; + case 4: + xmm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel( + ajcp, attr); + jit_ker = xmm_kernel_->jit_ker_; + return; + default: + assert(!"invalid channel blocking"); + } + } + + ~jit_avx512_core_x8s8s32x_fwd_kernel() { + delete xmm_kernel_; + delete ymm_kernel_; + delete zmm_kernel_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_pd, + memory_desc_t &weights_pd, + memory_desc_t &dst_pd, + memory_desc_t &bias_pd, + const primitive_attr_t &attr, + int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + void (*jit_ker)(jit_conv_call_s *); + _jit_avx512_core_x8s8s32x_fwd_kernel *zmm_kernel_; + _jit_avx512_core_x8s8s32x_fwd_kernel *ymm_kernel_; + _jit_avx512_core_x8s8s32x_fwd_kernel *xmm_kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp new file mode 100644 index 0000000000..cdbf333d5e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp @@ -0,0 +1,423 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_x8s8s32x_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace nstl; + +using jit_conv_ker_t = void (*)(jit_conv_call_s *); + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(&w[offset]) : 0; + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; + int group_block = jcp.ch_block; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; + + parallel(0, [&](const int ithr, const int nthr) { + + int start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_call_s(); + + int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 }; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, + nb_groups, n, jcp.mb); + break; + case loop_gncw: + nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, + owb, jcp.nb_ow); + break; + case loop_ngcw: + nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks, + owb, jcp.nb_ow); + break; + case loop_nwcg: + nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, + gg, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int gb = gg * jcp.nb_ch_blocking; + int g = gb * group_block; + int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.nb_ic * jcp.ic_block; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0; + p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; + p.dst = dst + dst_d.blk_off(n, g_oc, ow_s); + p.src = src + src_d.blk_off(n, g_ic, iw_s); + p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0); + p.scales = &oscales[jcp.is_oc_scale * g_oc]; + p.oc_blocks = jcp.is_depthwise ? gb : ocb; + p.kh_padding = jcp.kh; + p.t_overflow = 0; + p.b_overflow = 0; + p.owb = owb; + + kernel_->jit_ker(&p); + + ++start; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups, + n, jcp.mb); + break; + case loop_gncw: + nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb, + jcp.nb_ow); + break; + case loop_ngcw: + nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb, + jcp.nb_ow); + break; + case loop_nwcg: + nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg, + nb_groups); + break; + default: assert(!"unsupported loop order"); + } + } + }); +} + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.ch_block == 1); + assert(jcp.nb_ch_blocking == 1); + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(&w[offset]) : 0; + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk; + int nb_groups = jcp.nb_ch; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow; + + parallel(0, [&](const int ithr, const int nthr) { + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_call_s(); + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 }; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g, + nb_groups, n, jcp.mb, oh_s, jcp.oh); + break; + case loop_ngcw: + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, + owb, jcp.nb_ow, oh_s, jcp.oh); + break; + case loop_nhwcg: + nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, + occ, oc_chunks, g, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + while (start < end) { + for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk; + occ1 += jcp.nb_oc_blocking) { + int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1; + int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; + + int g_ic = g * jcp.nb_ic * jcp.ic_block; + + int work_rem = end - start; + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + if (jcp.loop_order == loop_nhwcg) + oh_e = oh_s + 1; // step instead + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + auto bias_w = bias + ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) + : 0; + int32_t *compensation_w = (jcp.signed_input) + ? compensation + g_oc : 0; + + auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); + + auto scales = &oscales[jcp.is_oc_scale * g_oc]; + + for (int oj = oh_s, ij = ih_s; oj < oh_e; + ++oj, ij += jcp.stride_h) { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = nstl::min(jcp.kh, + div_up(max(0, -ij), dilate_h)); + int i_b_overflow = nstl::min(jcp.kh, div_up( + max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1), + dilate_h)); + int kh_padding = nstl::max(0, + jcp.kh - i_t_overflow - i_b_overflow); + + size_t wei_stride = (!jcp.signed_input) + ? i_t_overflow * wht_h_stride : 0; + p.src = src_w + i_t_overflow * dilate_h * src_h_stride; + p.dst = dst_w; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.oc_blocks = ocb; + p.kh_padding = kh_padding; + p.scales = scales; + p.t_overflow = i_t_overflow; + p.b_overflow = i_b_overflow; + p.owb = owb; + + kernel_->jit_ker(&p); + src_w += src_h_stride * jcp.stride_h; + dst_w += dst_h_stride; + } + } + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g, + nb_groups, n, jcp.mb, oh_s, jcp.oh); + break; + case loop_ngcw: + nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, + oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + break; + case loop_nhwcg: + ++start; + nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ, + oc_chunks, g, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + } + }); +} + +template +void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.ic_block == 1); + assert(jcp.oc_block == 1); + assert(jcp.nb_ic == 1); + assert(jcp.nb_oc == 1); + assert(jcp.nb_oc_blocking == 1); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast(&w[offset]) : 0; + int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; + int group_block = jcp.ch_block; + + parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups, + [&](int n, int oh_s, int owb, int gg) { + + auto p = jit_conv_call_s(); + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + int gb = gg * jcp.nb_ch_blocking; + int g = gb * group_block; + + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0; + int32_t *compensation_w = jcp.signed_input ? compensation + g : 0; + + auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, gb, 0); + + auto scales = &oscales[jcp.is_oc_scale * g]; + + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h)); + int i_b_overflow = nstl::min(jcp.kh, + div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1), + dilate_h)); + int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow); + + size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride; + p.src = src_w + i_t_overflow * dilate_h * src_h_stride; + p.dst = dst_w; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.oc_blocks = gb; + p.kh_padding = kh_padding; + p.scales = scales; + p.t_overflow = i_t_overflow; + p.b_overflow = i_b_overflow; + p.owb = owb; + + kernel_->jit_ker(&p); + }); +} + +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::u8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::u8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::s8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::s8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::s32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::s32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::f32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::f32>; +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp new file mode 100644 index 0000000000..203ebdf942 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8:", avx512_core, ""), + jit_avx512_core_x8s8s32x_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, data_type::s8, data_type::undef, + dst_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of(bias_md_.data_type, + data_type::f32, data_type::s32, data_type::s8, + data_type::u8)) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_x8s8s32x_fwd_kernel::init_conf( + jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, + *attr(), mkldnn_get_max_threads()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(scratchpad, + jcp_, *attr()); + + return status; + } + + jit_conv_conf_t jcp_; + }; + + jit_avx512_core_x8s8s32x_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_core_x8s8s32x_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + + ~jit_avx512_core_x8s8s32x_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + const auto &_pd = pd(); + if (_pd->ndims() == 3) + execute_forward_1d(ctx); + else if (_pd->jcp_.is_depthwise) + execute_forward_2d_dw(ctx); + else + execute_forward_2d(ctx); + return status::success; + } + +private: + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + void execute_forward_2d_dw(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_x8s8s32x_fwd_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp new file mode 100644 index 0000000000..142af1f541 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp @@ -0,0 +1,1034 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_avx512_core_x8s8s32x_deconvolution.hpp" + +#define GET_OFF(field) offsetof(jit_deconv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +using namespace nstl; + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \ + (d).blk_off(__VA_ARGS__)) + +status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const deconvolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, const bool with_bias, + memory_desc_t &bias_md, const primitive_attr_t &attr) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper bias_d(&bias_md); + + if (!(mayiuse(avx512_core) + && one_of(src_d.data_type(), data_type::u8, data_type::s8) + && weights_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + + jcp = zero(); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + jcp.signed_input = src_d.data_type() == data_type::s8; + const int ndims = jcp.ndims = dst_d.ndims(); + const bool is_1d = ndims == 3; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; + jcp.is_depthwise = true && with_groups + && utils::everyone_is(1, jcp.ic_without_padding, + jcp.oc_without_padding); + + /* TODO: future work, on hold until depthwise specialized kernel is + * implemented. */ + if (jcp.is_depthwise && jcp.signed_input) + return status::unimplemented; + + format_tag_t dat_tag = utils::pick(ndims - 3, + format_tag::nwc, format_tag::nhwc); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + + format_tag_t wei_tag = is_1d + ? (jcp.is_depthwise + ? Goiw16g : (with_groups ? gOIw4i16o4i : OIw4i16o4i)) + : (jcp.is_depthwise + ? Goihw16g : (with_groups ? gOIhw4i16o4i : OIhw4i16o4i)); + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (jcp.signed_input && !jcp.is_depthwise) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + jcp.with_bias = with_bias; + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + jcp.prop_kind = cd.prop_kind; + jcp.mb = src_d.dims()[0]; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + if (jcp.is_depthwise) { + jcp.ch_block = 16; + jcp.oc_block = 1; + jcp.ic_block = 1; + } else { + jcp.ch_block = 1; + jcp.oc_block = 16; + jcp.ic_block = 16; + + if (jcp.ngroups == 1) { + jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block); + jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block); + } + if (jcp.ic % jcp.ic_block != 0) + return status::unimplemented; + } + + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1) + || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1)) + return status::unimplemented; + + /* padding: bottom and right */ + jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.oh + jcp.t_pad - 1); + jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.ow + jcp.l_pad - 1); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.ver = ver_avx512_core; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.dst_dt = dst_d.data_type(); + jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef; + jcp.typesize_bia + = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + /* kernel blocking params */ + const int regs = jcp.ver == ver_vnni ? 30 : 28; + jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); + for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) + if (jcp.nb_oc % jcp.nb_oc_blocking == 0 + && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1)) + break; + + jcp.ur_w = regs / (jcp.nb_oc_blocking + 1); + int l_overflow = max( + 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); + + if (jcp.ow < jcp.ur_w) { + jcp.ur_w = jcp.ow; + jcp.ur_w_tail = 0; + } else { + for (; jcp.ur_w >= 1; jcp.ur_w--) { + /* ur_w should be multiple of stride_w in order + to simplify logic for get_ow_start and get_ow_end */ + bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0; + + /* boundary conditions: + These conditions ensure all elements close to boundary + are computed in a single call of compute loop */ + bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + int r_overflow_no_tail + = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - max(0, jcp.r_pad) - jcp.ur_w_tail) + / jcp.stride_w); + bool right_boundary_covered + = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w; + + if (is_multiple_of_stride && left_boundary_covered + && right_boundary_covered) + break; + else if (jcp.ur_w == 1) + /* The boundary conditions above are also important + to maintain simplicity of calls to icb_loop, + if those conditions are not satisfied, + then special cases will need to be added + to use correct l_overflow/r_overflow values + when different iterations of compute loop + work on the locations close to boundary. + So to keep code simple, return unimplemented + for extreme case when a good ur_w cannot be found. + */ + return status::unimplemented; + } + } + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn; + return status::success; +} + +bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + return false; +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + eltwise_injector_->compute_vector_range(0, nb_oc_block * ur_w); +} + +bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: + return (p.contain(sum, 0) && is_eltwise(1)) + || (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, + const primitive_attr_t &attr) { + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, 16); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w, + int l_overflow, int r_overflow, ker_block_t last_ic_block_flag, + bool h_padded) { + + const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w; + + auto src_offset = [=](int oj, int icb, int ki) { + return jcp.typesize_in + * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w) + * jcp.ngroups * jcp.ic_without_padding + + icb * 4); + }; + + auto kernel_offset = [=](int ocb, int icb, int ki) { + return jcp.typesize_in + * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all + + icb * jcp.oc_block * jcp.ic_block / 4 + + ki * ch_block_all); + }; + + auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else if (jcp.is_depthwise) { + vpmulld(zmm_tmp, vreg_src, vreg_wei); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } else { + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + for (int ki = 0; ki < jcp.kw; ki++) { + + int jj_start = get_ow_start(ki, l_overflow); + int jj_end = get_ow_end(ur_w, ki, r_overflow); + + int _start = (jcp.signed_input) ? 0 : jj_start; + int _end = (jcp.signed_input) ? ur_w : jj_end; + + int tail_size = jcp.ic_without_padding % 4; + int n_ic_blocks = jcp.is_depthwise ? + 1 : + (last_ic_block_flag & ~no_last_block ? + div_up(jcp.ic_without_padding % jcp.ic_block, + 4) : + jcp.ic_block / 4); + + for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) { + if (h_padded == true) { + /* fill padded area with shifted values */ + Zmm inp = zmm_inp(0, jcp.nb_oc_blocking); + vpxord(inp, inp, inp); + vpsubb(inp, inp, zmm_shift); + } else { + + for (int jj = _start; jj < _end; jj += ur_w_stride) { + + int aux_src_off = src_offset(jj, icb1, ki); + + if (jj >= jj_start && jj < jj_end + && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) { + if (jcp.is_depthwise) { + vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking), + EVEX_compress_addr( + aux_reg_src, aux_src_off)); + } else if ((last_ic_block_flag & last_sp_block) + && tail_size != 0 && icb1 == n_ic_blocks - 1) { + xmm_t xmm_tmp = xmm_t( + zmm_inp(jj, jcp.nb_oc_blocking).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_src + aux_src_off + r], r); + vpbroadcastd( + zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp); + } else { + vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking), + EVEX_compress_addr( + aux_reg_src, aux_src_off)); + } + if (jcp.signed_input) + vpsubb(zmm_inp(jj, jcp.nb_oc_blocking), + zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift); + } else { + /* fill padded area with shifted values */ + if (jcp.signed_input) { + Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking); + vpxord(inp, inp, inp); + vpsubb(inp, inp, zmm_shift); + } + } + } + } + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + int aux_filt_off = kernel_offset(ocb, icb1, ki); + + if (_end - _start > 0) { + if (jcp.is_depthwise) + vpmovsxbd(zmm_wei, + EVEX_compress_addr(aux_reg_filt, aux_filt_off)); + else + vmovups(zmm_wei, + EVEX_compress_addr(aux_reg_filt, aux_filt_off)); + } + for (int jj = _start; jj < _end; jj += ur_w_stride) { + Zmm inp = (h_padded == true) ? + zmm_inp(0, jcp.nb_oc_blocking) : + zmm_inp(jj, jcp.nb_oc_blocking); + compute(zmm_out(jj, ocb), zmm_wei, inp); + } + } + } + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w, + int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) { + + int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ngroups * jcp.ic_without_padding; + const int stride_h = jcp.signed_input ? 1 : jcp.stride_h; + int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h; + + Label kh_loop_label, skip_kh_loop; + Label t_overflow_label, no_t_overflow_label, b_overflow_label, + no_b_overflow_label; + + mov(aux_reg_src, reg_src); + mov(aux_reg_filt, reg_filt); + + if (jcp.signed_input && jcp.ndims > 3) { + /* Weights are transposed, so first compute 'bottom' padding. */ + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + cmp(reg_overflow, 0); + je(no_b_overflow_label, T_NEAR); + L(b_overflow_label); { + compute_ker(ur_w, 0, 0, last_ic_block_flag, true); + + add(aux_reg_filt, shift_filt_kh); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(b_overflow_label, T_NEAR); + } + L(no_b_overflow_label); + } + + mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); + + if (jcp.signed_input || ((!jcp.signed_input) + && ((min(jcp.t_pad, jcp.b_pad) < 0) + || ((jcp.kh - 1) * (jcp.dilate_h + 1) + < nstl::max(jcp.t_pad, jcp.b_pad))))) { + cmp(reg_kh, 0); + je(skip_kh_loop, T_NEAR); + } + + L(kh_loop_label); { + compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false); + sub(aux_reg_src, shift_src_ih); + add(aux_reg_filt, shift_filt_kh); + dec(reg_kh); + + /* Insert weight compensation in stride 'holes' */ + if (jcp.signed_input && jcp.stride_h > 1) { + Label kh_comp_loop; + + cmp(reg_kh, 0); + je(skip_kh_loop, T_NEAR); + mov(reg_comp_strides, jcp.stride_h - 1); + L(kh_comp_loop); + { + compute_ker( + ur_w, 0, 0, last_ic_block_flag, true); + add(aux_reg_filt, shift_filt_kh); + dec(reg_comp_strides); + cmp(reg_comp_strides, 0); + jg(kh_comp_loop, T_NEAR); + } + } + cmp(reg_kh, 0); + jg(kh_loop_label, T_NEAR); + } + L(skip_kh_loop); + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + cmp(reg_overflow, 0); + je(no_t_overflow_label, T_NEAR); + L(t_overflow_label); { + compute_ker(ur_w, 0, 0, last_ic_block_flag, true); + + add(aux_reg_filt, shift_filt_kh); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(t_overflow_label, T_NEAR); + } + L(no_t_overflow_label); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) { + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + vpxord(zmm, zmm, zmm); + } + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)-128); + vpbroadcastb(zmm_shift, _t8); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps( + data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( + int ur_w, bool last_oc_block) { + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + + if (jcp.signed_input) + mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); + + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale + = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) { + mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_bias_alpha); + vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); + } + + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; + int scale_offset + = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block); + + auto zmm_bias = zmm_tmp; + if (jcp.with_bias) { + int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block; + auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); + cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); + } + if (jcp.signed_input) { + int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block; + auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); + cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag); + } + + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + vcvtdq2ps(zmm, zmm); + if (jcp.signed_input) + vaddps(zmm, zmm, zmm_comp); + if (jcp.with_bias) + vaddps(zmm, zmm, zmm_bias); + zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm; + vmulps(mask_zmm, zmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); + } + } + if (maybe_eltwise(0)) + compute_eltwise(ur_w); + if (p_sum_scale) { // post_op: sum + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + const bool mask_flag + = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1; + for (int j = 0; j < ur_w; j++) { + int aux_output_offset + = jcp.typesize_out + * (k * jcp.oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_dst, aux_output_offset); + Zmm zmm = zmm_out(j, k); + cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) + vaddps(zmm, zmm_prev_dst); + else + vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + if (maybe_eltwise(1)) + compute_eltwise(ur_w); + + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + if (jcp.dst_dt == data_type::u8) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + vmaxps(zmm, zmm_zero, zmm); + } + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(zmm, zmm); + } + for (int ur = 0; ur < ur_w; ur++) { + int aux_dst_off = jcp.typesize_out + * (ur * jcp.ngroups * jcp.oc_without_padding + + ocb * jcp.oc_block); + auto addr = EVEX_compress_addr(reg_dst, aux_dst_off); + + zmm_t zmm = zmm_out(ur, ocb); + zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm; + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: vmovups(addr, r_zmm); break; + case data_type::s8: vpmovsdb(addr, r_zmm); break; + case data_type::u8: vpmovusdb(addr, r_zmm); break; + default: assert(!"unknown dst_dt"); + } + } + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop( + int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) { + + int shift_src_icb = jcp.typesize_in * jcp.ic_block; + int shift_filt_icb + = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block; + + prepare_output(ur_w); + + Label skip_icb_loop, icb_loop_label; + + mov(reg_icb, jcp.nb_ic); + L(icb_loop_label); { + + if (jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + cmp(reg_icb, 1); + jg(common_ker, T_NEAR); + + kh_loop(ur_w, l_overflow, r_overflow, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + kh_loop(ur_w, l_overflow, r_overflow, no_last_block); + + L(end_ker); + } else { + kh_loop(ur_w, l_overflow, r_overflow, no_last_block); + } + + add(reg_src, shift_src_icb); + add(reg_filt, shift_filt_icb); + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_loop_label, T_NEAR); + } + + /* come-back pointers */ + sub(reg_src, jcp.nb_ic * shift_src_icb); + sub(reg_filt, jcp.nb_ic * shift_filt_icb); + L(skip_icb_loop); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - 1); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + jne(common_store, T_NEAR); + + store_output(ur_w, true); + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, false); + + L(end_store); + + } else { + store_output(ur_w, false); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() { + preamble(); + + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(zmm_one, _t); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise ? + jcp.ngroups % jcp.ch_block : + jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + Reg32 regw_tmp = reg_nur_w.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + mov(reg_src, ptr[param1 + GET_OFF(src)]); + mov(reg_filt, ptr[param1 + GET_OFF(filt)]); + mov(reg_dst, ptr[param1 + GET_OFF(dst)]); + + int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups + * jcp.oc_without_padding; + int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups + * jcp.ic_without_padding; + + int l_overflow = max( + 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); + int r_overflow + = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)) + / jcp.stride_w); + + int r_overflow1 + = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) + / jcp.stride_w); + int nur_w = jcp.ow / jcp.ur_w; + if (r_overflow1 > 0) + nur_w--; + + if (jcp.ur_w == jcp.ow) { + icb_loop(jcp.ur_w, l_overflow, r_overflow, true); + } else if (nur_w == 0) { + icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + if (jcp.ur_w_tail != 0) + icb_loop(jcp.ur_w_tail, 0, r_overflow, true); + } else { + xor_(reg_nur_w, reg_nur_w); + if (l_overflow > 0) { + icb_loop(jcp.ur_w, l_overflow, 0, false); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + inc(reg_nur_w); + } + if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) { + Label ow_loop_label; + L(ow_loop_label); + { + icb_loop(jcp.ur_w, 0, 0, false); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + inc(reg_nur_w); + cmp(reg_nur_w, nur_w); + jl(ow_loop_label, T_NEAR); + } + } + if (r_overflow1 > 0) { + icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + } + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_overflow, true); + } + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +template +void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + auto &jcp = kernel_->jcp; + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch; + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales + = scratchpad(ctx).template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + auto w = const_cast(weights); + int32_t *compensation + = (jcp.signed_input) ? reinterpret_cast(&w[offset]) : 0; + + parallel(0, [&](const int ithr, const int nthr) { + int start{ 0 }, end{ 0 }; + int work_amount = jcp.mb * nb_groups * oc_chunks; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_deconv_call_s(); + + int n{ 0 }, g{ 0 }, occ{ 0 }; + if (jcp.loop_order == loop_ngc) + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks); + else if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb); + else + assert(!"unsupported loop order"); + while (start < end) { + + int ocb = occ * jcp.nb_oc_blocking; + int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.ch_block * jcp.ic; + + p.dst = dst + dst_d.blk_off(n, g_oc); + p.src = src + src_d.blk_off(n, g_ic); + p.filt = weights + wht_blk_off(weights_d, g, ocb, 0); + p.bias = jcp.with_bias ? + bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : + 0; + p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; + p.scales = &oscales[jcp.is_oc_scale * g_oc]; + p.t_overflow = 0; + p.b_overflow = 0; + p.kh_padding = jcp.kh; + p.oc_blocks = jcp.is_depthwise ? g : ocb; + + kernel_->jit_ker(&p); + + ++start; + if (jcp.loop_order == loop_ngc) + nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks); + else if (jcp.loop_order == loop_cgn) + nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb); + else + assert(!"unsupported loop order"); + } + }); +} + +template +void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + auto &jcp = kernel_->jcp; + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch; + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales + = scratchpad(ctx).template get(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + auto w = const_cast(weights); + int32_t *compensation + = (jcp.signed_input) ? reinterpret_cast(&w[offset]) : 0; + + parallel(0, [&](const int ithr, const int nthr) { + int start{ 0 }, end{ 0 }; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_deconv_call_s(); + + /*loop order = cgn*/ + int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }; + if (jcp.loop_order == loop_ngc) + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, + oh_s, jcp.oh); + else if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, + oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + while (start < end) { + + int ocb = occ * jcp.nb_oc_blocking; + int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.ch_block * jcp.ic; + int work_rem = end - start; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + + auto dst_w = dst + dst_d.blk_off(n, g_oc); + auto src_w = src + src_d.blk_off(n, g_ic); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); + auto bias_w = jcp.with_bias ? + bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : + 0; + int32_t *compensation_w + = (jcp.signed_input) ? compensation + g_oc : 0; + + auto scales = &oscales[jcp.is_oc_scale * g_oc]; + for (int oj = oh_s; oj < oh_e; oj++) { + int ih_max = 0, kh_lo = 0, kh_len = 0; + if (jcp.dilate_h != 0 && jcp.stride_h == 1) { + /* dilation */ + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int o_t_overflow = div_up( + max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), + dilate_h); + int o_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh + + oj - jcp.b_pad), + dilate_h); + kh_len = jcp.kh - o_t_overflow - o_b_overflow; + kh_lo = o_b_overflow; + ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; + } else { + int o_t_overflow = max( + 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); + int o_b_overflow + = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) + / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 + - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h; + int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; + + kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h + + 1 - o_t_overflow - o_b_overflow; + kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; + ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; + } + + int wei_stride + = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0; + p.src = src_w + ih_max * src_h_stride; + p.dst = dst_w + oj * dst_h_stride; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.t_overflow = max( + 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h + + 1)); + p.b_overflow = kh_lo; + p.kh_padding = kh_len; + p.scales = scales; + p.oc_blocks = jcp.is_depthwise ? g : ocb; + kernel_->jit_ker(&p); + } + if (jcp.loop_order == loop_ngc) + nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, + oc_chunks, oh_s, jcp.oh); + else if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, + jcp.mb, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + }); +} + +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp new file mode 100644 index 0000000000..901038fa48 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp @@ -0,0 +1,237 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "cpu_primitive.hpp" +#include "cpu_memory.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "nstl.hpp" + +#include "cpu_deconvolution_pd.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef enum { + no_last_block = 0x1U, + last_ic_block = 0x2U, + last_sp_block = 0x4U, +} ker_block_t; + +struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); + + jit_avx512_core_x8s8s32x_deconv_fwd_kernel( + const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32( + this, jcp.eltwise); + generate(); + jit_ker = (void (*)(jit_deconv_call_s *))getCode(); + } + + ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const deconvolution_desc_t &cd, + memory_desc_t &src_md, + memory_desc_t &weights_md, + memory_desc_t &dst_md, + const bool with_bias, + memory_desc_t &bias_md, + const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + const jit_conv_conf_t &jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_deconv_call_s *); +private: + jit_uni_eltwise_injector_f32 *eltwise_injector_; + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + using xmm_t = const Xbyak::Xmm; + + reg64_t reg_src = r8; + reg64_t reg_filt = r9; + reg64_t reg_dst = r10; + reg64_t param1 = abi_param1; + reg64_t reg_kh = abi_not_param1; + reg64_t reg_nur_w = rbx; + reg64_t reg_bias = rdx; + reg64_t reg_icb = reg_bias; + reg64_t reg_ptr_scales = rax; + reg64_t reg_oc_blocks = rsi; + + reg64_t aux_reg_src = r11; + reg64_t aux_reg_filt = r12; + + reg64_t reg_compensation = r14; + reg64_t reg_scratch = r14; + reg64_t reg_ptr_sum_scale = r11; + reg64_t reg_bias_alpha = abi_not_param1; + reg64_t reg_overflow = rax; + reg64_t reg_comp_strides = reg_overflow; + + Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + zmm_t zmm_tmp = zmm_t(28); + zmm_t zmm_one = zmm_t(29); + /* used during write-out section of store_output */ + zmm_t zmm_zero = zmm_t(31); + zmm_t zmm_wei = zmm_t(31); + + /* signed input */ + zmm_t zmm_shift = zmm_t(30); + zmm_t zmm_comp = zmm_t(30); + zmm_t zmm_bias = zmm_t(31); + zmm_t zmm_prev_dst = zmm_t(31); + + zmm_t zmm_out(int i_ur, int i_oc) { + int idx = i_ur * jcp.nb_oc_blocking + i_oc; + assert(idx < 31); + return zmm_t(idx); + } + zmm_t zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return zmm_t(idx); + } + zmm_t zmm_bias_alpha() { + return zmm_t(jcp.nb_oc_blocking * jcp.ur_w); + } + xmm_t xmm_bias_alpha() { + return xmm_t(jcp.nb_oc_blocking * jcp.ur_w); + } + + int get_ow_start(int ki, int l_overflow) { + int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + return res; + } + + int get_ow_end(int ur_w, int ki, int r_overflow) { + if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + return ur_w - res; + } + bool maybe_eltwise(int position); + void compute_eltwise(int ur_w); + void prepare_output(int ur_w); + void store_output(int ur_w, bool last_oc_block); + void compute_ker(int ur_w, int l_overflow, int r_overflow, + ker_block_t last_ic_block_flag, bool h_padded = false); + void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); + void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); + void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); +}; + +template +struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_deconvolution_fwd_pd_t { + using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_deconvolution:", avx512_core, ""), + _jit_avx512_core_x8s8s32x_deconvolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && (desc()->alg_kind & alg_kind::deconvolution_direct) + && desc()->src_desc.data_type == src_type + && desc()->dst_desc.data_type == dst_type + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && desc()->accum_data_type == data_type::s32; + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: + init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, + with_bias(), bias_md_, *attr()); + + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(scratchpad, + jcp_, *attr()); + + return status::success; + } + + jit_conv_conf_t jcp_; + }; + + _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + + ~_jit_avx512_core_x8s8s32x_deconvolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if(pd()->ndims() == 3) + execute_forward_1d(ctx); + else + execute_forward_2d(ctx); + return status::success; + } + +private: + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_core_x8s8s32x_deconv_fwd_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp new file mode 100644 index 0000000000..c09592d5c9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp @@ -0,0 +1,773 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_GENERATOR_HPP +#define CPU_JIT_AVX2_GENERATOR_HPP + +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_isa_traits.hpp" +#include "jit_utils/jit_utils.hpp" + +#if defined(_WIN32) && !defined(__GNUC__) +# define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__ +#else +# define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al))) +#endif + +#if defined(_WIN32) +# define OFFSET_SHADOWSPACE 0x28 +#endif + +#define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \ + const char *name() const override { return STRINGIFY(jit_name); } \ + const char *source_file() const override { return __FILE__; } + +namespace mkldnn { +namespace impl { +namespace cpu { + +// TODO: move this to jit_generator class? +namespace { + +typedef enum { + PAGE_4K = 4096, + PAGE_2M = 2097152, +} cpu_page_size_t; + +// TODO: move this somewhere else? Although this is only used by jit kernels +// (Roma) +static inline int float2int(float x) { + union { + float vfloat; + int vint; + } cvt; + cvt.vfloat = x; + return cvt.vint; +} + +// TODO: A GPR class that hides ABI details from the JIT kernels and allows +// numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and +// stack register (sr). +// +// This will allow using syntax like this: +// +// param = gpr0; +// reg_input = gpr0; +// reg_output = gpr1; +// ... +// +// #ifndef XBYAK64 +// mov(param, ptr[sr]) +// #endif +// +// (Roma) + +#ifdef XBYAK64 +constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = { + Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, + Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15, +#ifdef _WIN32 + Xbyak::Operand::RDI, Xbyak::Operand::RSI, +#endif +}; + +#ifdef _WIN32 +static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX), + abi_param2(Xbyak::Operand::RDX), + abi_param3(Xbyak::Operand::R8), + abi_param4(Xbyak::Operand::R9), + abi_not_param1(Xbyak::Operand::RDI); +#else +static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI), + abi_param2(Xbyak::Operand::RSI), + abi_param3(Xbyak::Operand::RDX), + abi_param4(Xbyak::Operand::RCX), + abi_param5(Xbyak::Operand::R8), + abi_param6(Xbyak::Operand::R9), + abi_not_param1(Xbyak::Operand::RCX); +#endif +#endif + +inline unsigned int get_cache_size(int level, bool per_core = true){ + unsigned int l = level - 1; + // Currently, if XByak is not able to fetch the cache topology + // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core. + if (cpu.getDataCacheLevels() == 0){ + const int L1_cache_per_core = 32000; + const int L2_cache_per_core = 512000; + const int L3_cache_per_core = 1024000; + int num_cores = per_core ? 1 : mkldnn_get_max_threads(); + switch(l){ + case(0): return L1_cache_per_core * num_cores; + case(1): return L2_cache_per_core * num_cores; + case(2): return L3_cache_per_core * num_cores; + default: return 0; + } + } + if (l < cpu.getDataCacheLevels()) { + return cpu.getDataCacheSize(l) + / (per_core ? cpu.getCoresSharingDataCache(l) : 1); + } else + return 0; +} + +} + +class jit_generator : public Xbyak::CodeGenerator +{ +private: + const size_t xmm_len = 16; +#ifdef _WIN32 + const size_t xmm_to_preserve_start = 6; + const size_t xmm_to_preserve = 10; +#else + const size_t xmm_to_preserve_start = 0; + const size_t xmm_to_preserve = 0; +#endif + + const size_t num_abi_save_gpr_regs + = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]); + + const size_t size_of_abi_save_regs + = num_abi_save_gpr_regs * rax.getBit() / 8 + + xmm_to_preserve * xmm_len; + +public: + enum { + _cmp_eq_oq = 0u, + _cmp_lt_os = 1u, + _cmp_le_os = 2u, + _cmp_neq_uq = 4u, + _cmp_nlt_us = 5u, + _cmp_nle_us = 6u, + + _op_floor = 1u, + _op_mxcsr = 4u, + }; + + Xbyak::Reg64 param1 = abi_param1; + const int EVEX_max_8b_offt = 0x200; + const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; + + inline size_t get_size_of_abi_save_regs() { + return size_of_abi_save_regs; + } + + void preamble() { + if (xmm_to_preserve) { + sub(rsp, xmm_to_preserve * xmm_len); + for (size_t i = 0; i < xmm_to_preserve; ++i) + movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i)); + } + for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) + push(Xbyak::Reg64(abi_save_gpr_regs[i])); + if (mayiuse(avx512_common)) { + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + } + } + + void mic_prefetcht0(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht0(a); + } + + void mic_prefetcht1(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht1(a); + } + + void mic_prefetcht2(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht2(a); + } + + void uni_vzeroupper() { + if (mayiuse(avx) && !mayiuse(avx512_mic)) + vzeroupper(); + } + + void postamble() { + for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) + pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i])); + if (xmm_to_preserve) { + for (size_t i = 0; i < xmm_to_preserve; ++i) + movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]); + add(rsp, xmm_to_preserve * xmm_len); + } + uni_vzeroupper(); + ret(); + } + + template + Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, + T raw_offt, bool bcast = false) + { + using Xbyak::Zmm; + using Xbyak::Reg64; + using Xbyak::Address; + using Xbyak::RegExp; + + assert(raw_offt <= INT_MAX); + auto offt = static_cast(raw_offt); + + int scale = 0; + + if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { + offt = offt - 2 * EVEX_max_8b_offt; + scale = 1; + } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { + offt = offt - 4 * EVEX_max_8b_offt; + scale = 2; + } + + auto re = RegExp() + base + offt; + if (scale) + re = re + reg_EVEX_max_8b_offt * scale; + + if (bcast) + return zword_b [re]; + else + return zword [re]; + } + + Xbyak::Address make_safe_addr(const Xbyak::Reg64 ®_out, size_t offt, + const Xbyak::Reg64 &tmp_reg, bool bcast = false) { + if (offt > INT_MAX) { + mov(tmp_reg, offt); + return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg]; + } else { + return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt]; + } + } + + Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base, + size_t raw_offt, const Xbyak::Reg64 ®_offt, bool bcast = false) { + if (raw_offt > INT_MAX) { + return make_safe_addr(base, raw_offt, reg_offt, bcast); + } else { + return EVEX_compress_addr(base, raw_offt, bcast); + } + } + + void safe_add(const Xbyak::Reg64 &base, size_t raw_offt, + const Xbyak::Reg64 ®_offt) { + if (raw_offt > INT_MAX) { + mov(reg_offt, raw_offt); + add(base, reg_offt); + } else { + add(base, raw_offt); + } + } + + void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt, + const Xbyak::Reg64 ®_offt) { + if (raw_offt > INT_MAX) { + mov(reg_offt, raw_offt); + sub(base, reg_offt); + } else { + sub(base, raw_offt); + } + } + + // Disallow char-based labels completely + void L(const char *label) = delete; + void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } + + void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + pxor(x2, op); + } + void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + if (mayiuse(avx2)) { + vpxor(x1, x2, op); + } else { + vxorps(x1, x2, op); + } + } + void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2, + const Xbyak::Operand &op) { + vpxord(x1, x2, op); + } + + void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) { + movss(addr, x); + } + void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) { + vmovss(addr, x); + } + void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) { + movss(x, addr); + } + void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) { + vmovss(x, addr); + } + + void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) { + movsd(addr, x); + } + void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) { + vmovsd(addr, x); + } + void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) { + movsd(x, addr); + } + void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) { + vmovsd(x, addr); + } + + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movdqu(addr, x); + } + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovdqu(addr, x); + } + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) { + vmovdqu32(addr, x); + } + + void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) { + movdqu(x, addr); + } + void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) { + vmovdqu(x, addr); + } + void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) { + vmovdqu32(x, addr); + } + + void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movups(addr, x); + } + void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovups(addr, x); + } + + void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movups(x, op); + } + void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vmovups(x, op); + } + + void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movntps(addr, x); + } + void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovntps(addr, x); + } + + void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movss(x, op); + shufps(x, x, 0x0); + } + void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + if (op.isMEM() || mayiuse(avx2)) { + vbroadcastss(x, op); + } else { + Xbyak::Xmm t(x.getIdx()); + if (t.getIdx() != op.getIdx()) movss(t, op); + vinsertf128(x, x, t, 1); + vshufps(x, x, x, 0); + } + } + + void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movsd(x, op); + pshufd(x, x, 0x0); + } + void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + if (mayiuse(avx2)) { + vpbroadcastd(x, op); + } else { + Xbyak::Xmm t(x.getIdx()); + if (t.getIdx() != op.getIdx()) movsd(t, op); + vinsertf128(x, x, t, 1); + vshufps(x, x, x, 0); + } + } + + void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + rcpss(x, op); + } + void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) { + Xbyak::Xmm x1_(x1.getIdx()); + Xbyak::Xmm x2_(x2.getIdx()); + vrcpss(x1_, x1_, x2_); + } + void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) { + Xbyak::Xmm x_(x.getIdx()); + vrcpss(x_, x_, op); + } + + void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + rcpps(x, op); + } + void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vrcpps(x, op); + } + void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) { + vrcp14ps(x, op); + } + + void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + divps(x, op2); + } + void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vdivps(x, op1, op2); + } + + void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { + movups(buf, op1); + divps(buf, op2); + if (x.getIdx() != buf.getIdx()) { + movups(x, buf); + } + } + + void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { + vdivps(x, op1, op2); + } + + void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + addps(x, op2); + } + void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vaddps(x, op1, op2); + } + + void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, + const Xbyak::Operand& op) { + assert(x1.getIdx() == x2.getIdx()); + psignd(x1, op); + } + void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2, + const Xbyak::Operand& op) { + vpsignd(x1, x2, op); + } + + void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + subps(x, op2); + } + void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vsubps(x, op1, op2); + } + + void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { + movups(buf, op1); + subps(buf, op2); + if (x.getIdx() != buf.getIdx()) { + movups(x, buf); + } + } + + void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { + vsubps(x, op1, op2); + } + + void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + mulps(x, op2); + } + void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vmulps(x, op1, op2); + } + + void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x1, x2); + addps(x1, op); + } + void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfmadd213ps(x1, x2, op); + } + + void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x2, op); + addps(x1, x2); + } + void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfmadd231ps(x1, x2, op); + } + + void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x2, op); + subps(x1, x2); + } + + void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfnmadd231ps(x1, x2, op); + } + + void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + sqrtps(x, op); + } + void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vsqrtps(x, op); + } + + void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + paddd(x2, op); + } + void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + vpaddd(x1, x2, op); + } + + void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + assert(x1.getIdx() == x2.getIdx()); + andps(x1, op); + } + void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + if (!mayiuse(avx512_common) || x1.getBit() < 512) + vandps(x1, x2, op); + else + vpandd(x1, x2, op); + } + + void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + assert(x1.getIdx() == x2.getIdx()); + orps(x1, op); + } + void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + if (!mayiuse(avx512_common) || x1.getBit() < 512) + vorps(x1, x2, op); + else + vpord(x1, x2, op); + } + + void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + assert(x.getIdx() == op.getIdx()); + pslld(x, imm); + } + void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vpslld(x, op, imm); + } + + void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + assert(x.getIdx() == op.getIdx()); + psrld(x, imm); + } + void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vpsrld(x, op, imm); + } + + void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + maxps(x, op2); + } + void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vmaxps(x, op1, op2); + } + + void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + minps(x, op2); + } + void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vminps(x, op1, op2); + } + + void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nle_us); + } + + void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vcmpgtps(x1, x2, op); + } + + void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nlt_us); + } + + void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vcmpps(x1, x2, op, _cmp_nlt_us); + } + + void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) { + ptest(x1, op); + } + + void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) { + assert(!(x1.isZMM() || op.isZMM())); + vtestps(x1, op); + } + + void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op, const Xbyak::Xmm &msk) { + assert(x1.getIdx() == x2.getIdx()); + assert(msk.getIdx() == 0); + blendvps(x1, op); + } + void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op, const Xbyak::Ymm &msk) { + vblendvps(x1, x2, op, msk); + } + + void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + roundps(x, op, imm); + } + void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vroundps(x, op, imm); + } + + void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + cvtps2dq(x, op); + } + void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vcvtps2dq(x, op); + } + + void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + cvtdq2ps(x, op); + } + void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vcvtdq2ps(x, op); + } + + void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) { + movmskps(x1.cvt64(), x2); + } + void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) { + vmovmskps(x1, x2); + } + + void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ + assert(x1.getIdx() == x1.getIdx()); + packssdw(x1, op); + } + void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ + vpackssdw(x1, x2, op); + } + + void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ + assert(x1.getIdx() == x1.getIdx()); + packuswb(x1, op); + } + void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ + vpackuswb(x1, x2, op); + } + + + void mul_by_const(const Xbyak::Reg &out, + const Xbyak::Reg64 &tmp, int value) { + // Generates a shift + add sequence for multiplicating contents of the + // out register by a known JIT-time value. Clobbers the tmp register. + // + // Pros compared to mul/imul: + // - does not require using known registers + // - not microcoded on Intel(R) Xeon Phi(TM) processors + // Still, there are probably a lot of cases when mul/imul is faster on + // Intel(R) Core(TM) processors. Not intended for critical path. + + // TODO: detect when overflow is emminent (Roma) + // TODO: detect when using mul/imul is a better option (Roma) + + int p = 0; // the current power of 2 + int old_p = 0; // the last seen power of 2 such that value[old_p] != 0 + + xor_(tmp, tmp); + while (value) { + if (value & 1) { + int shift = p - old_p; + if (shift) { + shl(out, shift); + old_p = p; + } + add(tmp, out); + } + value >>= 1; + p++; + } + mov(out, tmp); + } + +public: + jit_generator( + void *code_ptr = nullptr, + size_t code_size = 256 * 1024 + ) : Xbyak::CodeGenerator(code_size, code_ptr) + { + } + virtual ~jit_generator() {} + + virtual const char *name() const = 0; + virtual const char *source_file() const = 0; + + const Xbyak::uint8 *getCode() { + const Xbyak::uint8 *code = CodeGenerator::getCode(); + size_t code_size = getSize(); + jit_utils::register_jit_code(code, code_size, name(), source_file()); + return code; + } + + template const F getCode() { + return (const F)getCode(); + } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp new file mode 100644 index 0000000000..56d7f592e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp @@ -0,0 +1,481 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_PRIMITIVE_CONF_HPP +#define JIT_PRIMITIVE_CONF_HPP + +#include + +#include "common/primitive_attr.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +/* convolution */ +enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_vnni}; +enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn, + loop_ngcw, loop_nhwcg, loop_nwcg}; +enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr, + loop_brl}; +enum conv_kernel_kind_t {embd_bcast, expl_bcast}; + +enum { + FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1, + FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3, + FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5, + FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7, + FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9, + FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips + loading weights-data from memory; this + needs to happen on the first Group/16 + iteration. */ + FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skip + loading bias data from memory */ + FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution + pass */ +}; + +struct jit_conv_conf_t { + prop_kind_t prop_kind; + conv_version_t ver; + conv_loop_order_t loop_order; + + int simd_w; + int ndims; + int mb; + int ngroups, ic, oc, oc_without_padding, ic_without_padding; + int id, ih, iw, od, oh, ow; + int f_pad, l_pad, t_pad; + int back_pad, r_pad, b_pad; + int kd, kh, kw; + int stride_d, stride_h, stride_w; + int dilate_d, dilate_h, dilate_w; + format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround + bool with_bias; + bool with_sum; + bool with_eltwise; + + post_ops_t::entry_t::eltwise_t eltwise; + + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + + int idp, ihp, iwp, ohp, owp; + int nb_ic, ic_block; + int nb_oc, oc_block; + int nb_ow, ow_block; + int nb_oc_blocking; /* used in jit kernels for nb_oc work bloking taking + into account vector registers distribution */ + int nb_oc_blocking_thr_chunk; /* used for distibution of nb_oc work + within threads */ + int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work + int nb_ic_L2; + int h_blocking; + int nb_oc_L2; + int ur_h, ur_w; + int ur_w_tail; + bool is_1stconv; + int nonblk_group_off; + /* fma avx512_core */ + conv_kernel_kind_t kernel_kind; + /* 4fma */ + int tr_iw; + int tr_src_num_guard_elems; + /* 1st conv: 4fma */ + int tr_ld; + int kh_step; + /* 4vnni */ + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + /* avx512_u8s8u8 */ + int ic_nb1, ic_nb2; + int oc_nb1; + int ur_ow_max, ur_ow, ur_ow_tail; + int ur_ow_nsteps; + data_type_t bia_dt; + data_type_t dst_dt; + /* avx512: max possible value is nregs(32) - aux_regs(4) */ + int src_offsets[28]; + int src_count; + bool expl_bcast; + bool large_spatial; + int is_oc_scale; + int max_regs_ur; // maximum accumulation registers + // dw conv + int nb_ch, ch_block, nb_ch_blocking; + bool is_depthwise, is_fast_depthwise, is_resrc_depthwise; + int aligned_threads; + // large spatial + int oh_blk_size; + // s8s8 convolution + bool signed_input; + float wei_adj_scale; +}; + +struct jit_conv_conf_2x3_wino_t { + conv_version_t ver; + + int m; + int r; + int alpha; + int tile_h, tile_w; + + int mb; + int ngroups, ic, oc, oc_without_padding; + int ih, iw, oh, ow; + int l_pad, t_pad; + int r_pad, b_pad; + int kh, kw; + int stride_h, stride_w; + int dilate_h, dilate_w; + + int nb_ic, ic_block; + int nb_oc, oc_block; + + int w_block_size, h_block_size; + + data_type_t bia_dt; + data_type_t dst_dt; + + int is_oc_scale; + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + + format_tag_t src_tag, dst_tag; // temporary workaround + bool with_bias; + bool small_mb; + + int xb, yb; + int inp_stride; + int out_stride; + int wei_stride; + int bia_stride; + + int M, N, K; + int m_block, n_block, k_block; + int n2_block, n_chunks; + int k2_block, k_chunks; + + int mb_block, nb_mb; + + size_t size_wino_src, size_wino_wei, size_wino_dst; + + int nthr; +}; + +/* + Winograd sched policy: + + Computation Unit: + W: weights transform + S: src transform + D: dst transform + G: gemm + + Thread grouping by: + i: nb_ic + o: nb_oc + t: tile_block + e: element in tile + + Note: 'i' and 'o' are omited if + i. not comblined with t or + ii. with discrete transforms + + Current policies supported: +*/ +enum winograd_sched_t { + WSCHED_INVALID = 0, + + /* Forward & backward-data */ + /* W_S_G_D implements discrete transforms */ + WSCHED_DATA_W_S_G_D, + /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/ + WSCHED_DATA_W_SGD, + + /* Backward-weights */ + WSCHED_WEI_S_D_G_W, + WSCHED_WEI_SDGtWo, + WSCHED_WEI_S_D_Giot_W, + WSCHED_WEI_SDGt_W, +}; + +struct jit_conv_winograd_conf_t : public jit_conv_conf_t { + int itiles; + int jtiles; + int ntiles; + int ic_simd_block=16; + int tile_4fma_padding; + int tile_4fma; + int oc_simd_block=16; + int oc_reg_block; + int ic_reg_block; + int tile_block; + int tile_block_ur; + int nb_tile_block_ur; + + bool double_buffering; + bool with_relu_postsum; + int zmm_start; + int nb_reg; + + int dimK; + int dimK_4fma; + int dimK_reg_block; + int dimK_block; + int dimK_nb_block; + + int dimM; + int dimM_reg_block; + int dimM_simd_block; + int dimM_block; + int dimM_nb_block; + + int dimN; + int dimN_reg_block; + int dimN_bcast_ur; + int dimN_block; + int dimN_nb_block; + + winograd_sched_t sched_policy; +}; + +struct jit_conv_call_s { + const void *src; /* hack, non-const for backward_data */ + const void *dst; /* hack, non-const for forward */ + const void *filt; /* hack, non-const for backward_weights */ + const void *bias; /* hack, non-const for backward_bias */ + const void *src_prf; + const void *dst_prf; + const void *filt_prf; + const void *bias_prf; + const void *scales; + const void *acc_s32; + const void *compensation; + size_t kd_offset; + size_t kd_offset_prf; + size_t d_index; + size_t d_index_prf; + size_t d_worksize; + size_t d_worksize_prf; + size_t kd_padding; + size_t kd_padding_prf; + size_t kh_padding; + size_t kh_padding_prf; + size_t owb; + size_t owb_prf; + size_t kw_padding; + size_t channel; + size_t channel_prf; + size_t oc_blocks; + size_t ur_w; + size_t ur_str_w; + size_t ch_blocks; + size_t t_overflow; + size_t b_overflow; + int flags; +}; + +struct jit_deconv_call_s { + const void *src; /* hack, non-const for backward_data */ + const void *dst; /* hack, non-const for forward */ + const void *filt; /* hack, non-const for backward_weights */ + const void *bias; /* hack, non-const for backward_bias */ + const void *scales; + const void *compensation; + size_t t_overflow; + size_t b_overflow; + size_t kh_padding; + size_t oc_blocks; +}; + +struct jit_dw_conv_call_s { + const void *input; + const void *output; + const void *filter; + const void *bias; + size_t kh_count; + size_t oh_count; + size_t oh_index; + size_t filter_pad_off; + unsigned char + exec_flags; /* Flags passed by driver execution to inner kernel */ +}; + +struct jit_wino_transform_call_s { + size_t tile_block; + size_t tile_block_ur; + size_t nb_tile_block_ur; + size_t tile_count; + size_t tj; + size_t ti; + void *src; + void *dst; + void *Mw; + void *M; + void *T; + void *G; + void *bias; +}; + +struct jit_1x1_conv_conf_t { + prop_kind_t prop_kind; + conv_version_t ver; + + int mb; + int ngroups, ic, oc, oc_without_padding, ic_without_padding; + int iw, ih, ow, oh; + int l_pad, t_pad; + int kh, kw; + int stride_h, stride_w; + format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround + bool with_bias; + bool with_sum; + bool with_eltwise; + + post_ops_t::entry_t::eltwise_t eltwise; + + int is, os; + int ic_block, oc_block; + + int ur, ur_tail; + + int reduce_dim, reduce_block, nb_reduce, + nb_reduce_blocking, nb_reduce_blocking_max; + int load_dim, load_block, nb_load, + nb_load_blocking, nb_load_blocking_max, nb_load_chunk; + int bcast_dim, bcast_block, nb_bcast, + nb_bcast_blocking, nb_bcast_blocking_max; + + int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step; + int load_loop_load_step, load_loop_iter_step; + int bcast_loop_output_step, bcast_loop_output_substep; + int bcast_loop_bcast_step, bcast_loop_bcast_substep; + int fma_step; + int load_grp_count; + conv_1x1_loop_order_t loop_order; + bool use_vmovntps; + /* avx512 core */ + bool expl_bcast; + /* 4vnni */ + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + /* 4fma */ + bool transpose_src; + int tr_is; + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + int is_oc_scale; + data_type_t bia_dt; + data_type_t dst_dt; + bool signed_input; + float wei_adj_scale; +}; + +struct jit_gemm_conv_conf_t { + prop_kind_t prop_kind; + + int mb; + int ngroups, ic, oc; + int iw, ih, id, ow, oh, od; + int l_pad, t_pad, f_pad; + int kh, kw, kd; + int stride_h, stride_w, stride_d; + int dilate_h, dilate_w, dilate_d; + bool with_bias; + + int is, os, ks; + int ic_block, oc_block; + + int nthr; + ptrdiff_t im2col_sz; + bool need_wei_reduction; + bool signed_input; + int oh_block; + int ow_block; + bool outer_threading; +}; + +struct jit_1x1_conv_call_s { + const void *bcast_data; + const void *load_data; + const void *output_data; + const void *bias_data; // used in forward and backward_weights only + const void *acc_s32; + const void *scales; + const void *compensation; + + size_t load_dim; + size_t bcast_dim; + size_t reduce_dim; + + size_t output_stride; // used in backward_weights only + + size_t first_last_flag; +}; + +/* pooling */ +struct jit_pool_conf_t { + int ndims; + int mb, c; + int id, ih, iw, od, oh, ow; + int stride_d, stride_h, stride_w; + int kd, kh, kw; + int f_pad, t_pad, l_pad; + alg_kind_t alg; + bool is_training; + bool pad_w_is_null; + bool is_backward; + bool simple_alg; + data_type_t ind_dt; + + int c_block, c_tail, nb_c; + int ur_c, ur_c_tail; + int ur_w; + int ur_w_tail; + size_t tail[4]; + data_type_t src_dt; + data_type_t dst_dt; +}; + +struct jit_pool_call_s { + const float *src; + const float *dst; + const void *indices; + const float *src_prf; + const float *dst_prf; + const void *indices_prf; + size_t oh; + size_t kd_padding; + size_t kh_padding; + size_t kh_padding_shift; + size_t kd_padding_shift; + size_t kw_padding; + const float* init_value; + float ker_area_h; +}; + + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp new file mode 100644 index 0000000000..94d2101d6e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp @@ -0,0 +1,677 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_sse42_1x1_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop( + int load_loop_blk, int ur) +{ + auto reg_load = [=](int i, int n) { + return Xmm(2*ur * load_loop_blk + 2*i + n + 1); + }; + + auto reg_accum = [=](int i, int j, int n) { + return Xmm(2*j * load_loop_blk + 2*i + n + 1); + }; + + auto bias_ptr = [=](int i, int n) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? (height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i, int n) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)]; + }; + + auto output_ptr = [=](int i, int j, int n) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)]; + default: + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)]; + } + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) { + movups(reg_accum(i, j, 0), bias_ptr(i, 0)); + movups(reg_accum(i, j, 1), bias_ptr(i, 1)); + } + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r0 = reg_accum(i, j, 0); + auto r1 = reg_accum(i, j, 1); + xorps(r0, r0); + xorps(r1, r1); + } + + L(init_done); + + // load weights + for (int i = 0; i < load_loop_blk; ++i) { + movups(reg_load(i, 0), load_ptr(0, i, 0)); + movups(reg_load(i, 1), load_ptr(0, i, 1)); + } + + movss(reg_bcast, bcast_ptr(0, 0)); + shufps(reg_bcast, reg_bcast, 0); + }; // init() + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r0 = reg_accum(i, j, 0); + auto r1 = reg_accum(i, j, 1); + addps(r0, output_ptr(i, j, 0)); + addps(r1, output_ptr(i, j, 1)); + } + + L(store_noadd); + + if (jcp.with_eltwise) { + assert(ur * load_loop_blk < 14); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + eltwise_injector_->compute_vector_range(1, + 2 * ur * load_loop_blk + 1); + + L(store_norelu); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + movups(output_ptr(i, j, 0), reg_accum(i, j, 0)); + movups(output_ptr(i, j, 1), reg_accum(i, j, 1)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + mulps(reg_load(i, 0), reg_bcast); + mulps(reg_load(i, 1), reg_bcast); + addps(reg_accum(i, j, 0), reg_load(i, 0)); + addps(reg_accum(i, j, 1), reg_load(i, 1)); + + if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) { + movups(reg_load(i, 0), load_ptr(u + 1, i, 0)); + movups(reg_load(i, 1), load_ptr(u + 1, i, 1)); + } + } + if (j < ur - 1) { + movss(reg_bcast, bcast_ptr(u, j + 1)); + shufps(reg_bcast, reg_bcast, 0); + } + } // for ur + if (!last_block || u < jcp.reduce_loop_unroll - 1) { + movss(reg_bcast, bcast_ptr(u + 1, 0)); + shufps(reg_bcast, reg_bcast, 0); + } + } // for reduce_loop_unroll + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} // reduce_loop() + +void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) +{ + if (!jcp.with_bias || jcp.prop_kind != backward_weights) + return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i, int n) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i, int n) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r0 = diff_bias_reg(i, 0); + auto r1 = diff_bias_reg(i, 1); + xorps(r0, r0); + xorps(r1, r1); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) { + movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0)); + movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1)); + } + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) { + addps(diff_bias_reg(i, 0), load_ptr(u, i, 0)); + addps(diff_bias_reg(i, 1), load_ptr(u, i, 1)); + } + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) { + movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0)); + movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1)); + } + + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_sse42_1x1_conv_kernel_f32::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + if (jcp.with_bias) { + if (jcp.prop_kind == backward_weights) { + sub(rsp, stack_space_needed); + mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + } else + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + } + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto generate_load_loop_body = [=] (int load_loop_blk) { + generate_bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + add(reg_output_data, + load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); + break; + case backward_weights: + for (int i = 0; i < load_loop_blk; i++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + Label load_loop_blk_8; + Label load_loop_blk_16; + Label load_loop_blk_24; + Label load_loop_blk_end; + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16, T_NEAR); + + cmp(reg_load_loop_work, 16); + jle(load_loop_blk_16, T_NEAR); + + L(load_loop_blk_24); { + generate_diff_bias_loop(3); + generate_load_loop_body(3); + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16); + cmp(reg_load_loop_work, 24); + jge(load_loop_blk_24); + } + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(sse42)) + return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + const int is_bwd_d = jcp.prop_kind == backward_data; + + format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + const int simd_w = 4; + jcp.ic_block = jcp.oc_block = simd_w*2; + + args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.ur = 1; + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp new file mode 100644 index 0000000000..b314a5098c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SSE42_1x1_CONV_KERNEL_F32_HPP +#define JIT_SSE42_1x1_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_1x1_conv_kernel_f32: public jit_generator { + jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); + } + + ~jit_sse42_1x1_conv_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_1x1_conv_kernel_f32) + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using xmm_t = const Xbyak::Xmm; + + reg64_t reg_bcast_data = rax; + reg64_t reg_load_data = rsi; + reg64_t reg_output_data = rbx; + reg64_t aux_reg_bcast_data = rdx; + reg64_t aux1_reg_bcast_data = abi_not_param1; + reg64_t aux_reg_load_data = abi_param1; + reg64_t aux_reg_output_data = rbp; + reg64_t reg_load_loop_work = r9; + reg64_t reg_bcast_loop_work = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t load_loop_iter = r13; + reg64_t imm_addr64 = load_loop_iter; + reg64_t bcast_loop_iter = r14; + reg64_t reduce_loop_iter = r15; + reg64_t reg_reduce_pos_flag = r8; + reg64_t reg_output_stride = r12; + reg64_t reg_bias_data = r12; + reg64_t reg_diff_bias_data = bcast_loop_iter; + + int reg_diff_bias_data_stack_offt = 0; + int stack_space_needed = 8; + + xmm_t reg_bcast = xmm_t(15); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate_bcast_loop(int load_loop_blk); + void generate_reduce_loop(int load_loop_blk, int ur); + void generate_diff_bias_loop(int load_loop_blk); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp new file mode 100644 index 0000000000..30c137641e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp @@ -0,0 +1,134 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "jit_sse42_1x1_convolution.hpp" +#include "utils.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +void jit_sse42_1x1_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int ndims = src_d.ndims(); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + parallel(0, [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto par_conv = jit_1x1_conv_call_s(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + const int bcast_step_rem = jcp.nb_bcast - osb; + int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max + ? bcast_step_rem : jcp.nb_bcast_blocking; + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + const int ow = os % jcp.ow; + const int oh = os / jcp.ow; + const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0); + const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0); + + par_conv.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + + int ocb = 0; + while (ocb < jcp.nb_load) { + const int load_step_rem = jcp.nb_load - ocb; + const int load_step = load_step_rem < jcp.nb_load_blocking_max + ? load_step_rem : jcp.nb_load_blocking; + + const size_t _ocb = g * nb_oc + ocb; + par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + load_step * jcp.oc_block); + + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + par_conv.output_data = &dst[dst_off]; + + par_conv.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + par_conv.first_last_flag = 0 + | (icb == 0) * FLAG_REDUCE_FIRST + | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST; + + par_conv.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking * jcp.ic_block); + + const size_t _icb = g * nb_ic + icb; + const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw); + par_conv.bcast_data = &src[src_off]; + + par_conv.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + kernel_->jit_ker(&par_conv); + } + + ocb += load_step; + } + + iwork += bcast_step; + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp new file mode 100644 index 0000000000..b32b1e4784 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_SSE42_1x1_CONVOLUTION_HPP +#define CPU_JIT_SSE42_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "jit_sse42_1x1_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_1x1_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", sse42, ""), + jit_sse42_1x1_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + return jit_sse42_1x1_conv_kernel_f32::init_conf(jcp_, *desc(), + *src_md(), *weights_md(), *dst_md(), *attr()); + } + + jit_1x1_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_sse42_1x1_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) { + kernel_ = new jit_sse42_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + } + ~jit_sse42_1x1_convolution_fwd_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_sse42_1x1_conv_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp new file mode 100644 index 0000000000..17cabc1186 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp @@ -0,0 +1,497 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "cpu_memory.hpp" + +#include "jit_sse42_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int kh = jcp.kh; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); + int jj_end = ur_w + - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w)); + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int inp_off; + if (one_of(jcp.src_tag, ncw, nchw)) + inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l); + else + inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2; + + movss(Xmm(oc_blocks * ur_w + jj + 1), + ptr[aux_reg_input + sizeof(float) * inp_off]); + shufps(Xmm(oc_blocks * ur_w + jj + 1), + Xmm(oc_blocks * ur_w + jj + 1), 0x0); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk + + ki * ic_blk * oc_blk + ifm2 * oc_blk; + + for (int jj = jj_start; jj < jj_end; jj++) + { + movups(xmm0, + ptr[aux_reg_kernel + sizeof(float) * ker_off]); + mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); + addps(Xmm(ur_w * ii + jj + 1), xmm0); + } + } + } + } +} + +void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + Label kw_loop; + + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int kh = jcp.kh; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + xor_(ki_iter, ki_iter); + L(kw_loop); + { + int jj_start = 0; + int jj_end = ur_w; + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int inp_off; + if (one_of(jcp.src_tag, ncw, nchw)) + inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l); + else + inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2; + + movss(Xmm(oc_blocks * ur_w + jj + 1), + ptr[aux_reg_input + sizeof(float) * inp_off]); + shufps(Xmm(oc_blocks * ur_w + jj + 1), + Xmm(oc_blocks * ur_w + jj + 1), 0x0); + } + for (int ii = 0; ii < oc_blocks; ii++) { + int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk + + ifm2 * oc_blk; + for (int jj = jj_start; jj < jj_end; jj++) { + movups(xmm0, + ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]); + mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); + addps(Xmm(ur_w * ii + jj + 1), xmm0); + } + } + } + add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw) ? + dilate_w : ic_blk * dilate_w)); + + inc(ki_iter); + cmp(ki_iter, kw); + jl(kw_loop, T_NEAR); + } +} + +void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ow = jcp.ow; + int oh = jcp.oh; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw) + ? dilate_h : ic_blk * dilate_h; + const int inp_off = one_of(jcp.src_tag, ncw, nchw) + ? dilate_w : ic_blk * dilate_w; + + xor_(simd_iter, simd_iter); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + Label init_simd_iter_loop; + Label init_done; + Label init_first; + + L(init_simd_iter_loop); + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first, T_NEAR); + } + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + movups(Xmm(ur_w * ii + jj + 1), xword[reg_output + + sizeof(float) * (ii * oh * ow + jj) * oc_blk]); + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done, T_NEAR); + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + addps(Xmm(ur_w * ii + jj + 1), + xword[reg_bias + sizeof(float) * ii * oc_blk]); + } + + jmp(init_done); + + L(init_first); + if (this->jcp.with_bias) { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + movups(Xmm(ur_w * ii + jj + 1), + xword[reg_bias + sizeof(float) * ii * oc_blk]); + } else { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1)); + } + + L(init_done); + + Label skip_kh_loop; + mov(kj, reg_kh); + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + cmp(kj, 0); + je(skip_kh_loop, T_NEAR); + } + Label kh_loop; + L(kh_loop); + { + if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { + oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks); + sub(aux_reg_input, sizeof(float) * kw * inp_off); + add(aux_reg_input, sizeof(float) * iw * inp_mult); + } else { + oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); + add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * iw * inp_mult); + } + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.with_eltwise) { + Label regular_store; + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store, T_NEAR); + + eltwise_injector_->compute_vector_range(1, oc_blocks * ur_w + 1); + + L(regular_store); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + const size_t o_off = (ii * oh * ow + jj) * oc_blk; + + Xmm reg_out = Xmm(ur_w * ii + jj + 1); + movups(xword[reg_output + sizeof(float) * o_off], reg_out); + } + } + + mov(aux_reg_kernel, reg_kernel); + mov(aux_reg_input, reg_input); + add(aux_reg_kernel, sizeof(float) * 4); + add(reg_output, sizeof(float) * 4); + add(reg_bias, sizeof(float) * 4); + + inc(simd_iter); + cmp(simd_iter, 2); + jl(init_simd_iter_loop, T_NEAR); + + sub(reg_output, sizeof(float) * 8); + sub(reg_bias, sizeof(float) * 8); +} + +inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks) +{ + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int n_oi = jcp.ow / ur_w; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + int dilate_w = jcp.dilate_w + 1; + int str_w = jcp.stride_w; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw) ? 1 : ic_blk; + + int l_pad = jcp.l_pad; + int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) + width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad" + else + width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad" + add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + Label ow_loop; + xor_(oi_iter, oi_iter); + + if (n_oi > 0) { + L(ow_loop); + + width_blk_step(ur_w, 0, 0, oc_blocks); // "middle" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + + if (r_pad1 > 0 && n_oi >=0) { + width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + if (ur_w_tail != 0) + width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail" +} + +void jit_sse42_conv_fwd_kernel_f32::generate() +{ + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + + int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; + Label tail, exit; + + cmp(reg_oc_blocks, jcp.nb_oc_blocking); + jne(nb_oc_tail ? tail : exit, T_NEAR); + + solve_common(jcp.nb_oc_blocking); + jmp(exit, T_NEAR); + + if (nb_oc_tail) { + L(tail); + cmp(reg_oc_blocks, nb_oc_tail); + jne(exit, T_NEAR); + solve_common(nb_oc_tail); + } + + L(exit); + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(sse42)) return status::unimplemented; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0]; + jcp.dilate_w = cd.dilates[ndims - 3]; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); + } + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + const bool flat = jcp.ic == 3; + const bool mimo = !flat; + + bool args_ok = true + && IMPLICATION(flat, one_of(jcp.src_tag, ncw, nwc, nchw, nhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o)) + && IMPLICATION(mimo, one_of(jcp.src_tag, nCw8c, nChw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c); + if (!args_ok) return status::unimplemented; + + const int simd_w = 8; // 2 SSE vectors processing at once + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.ur_w = 3; + if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ + + args_ok = true + && jcp.oc % simd_w == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) + || (jcp.stride_w == 1 && jcp.stride_h == 1)) + && IMPLICATION(mimo, jcp.ic % simd_w == 0); + if (!args_ok) return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + + // kernel needs 1 temporary YMM register + const int num_avail_regs = 15; + if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { + /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ + jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, + nstl::min(jcp.ow, num_avail_regs / 2)); + jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + /* check again ... */ + r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) + return status::unimplemented; + } + assert(jcp.nb_oc_blocking > 0); + assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.nb_ic_blocking = 12; + jcp.nb_ic_blocking_max = 16; + } else { + jcp.nb_ic_blocking = 1; + jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; + } + + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp new file mode 100644 index 0000000000..33c26ef081 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp @@ -0,0 +1,93 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SSE42_CONV_KERNEL_F32_HPP +#define JIT_SSE42_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_conv_fwd_kernel_f32: public jit_generator { + jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_sse42_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32) + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t aux_reg_input = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r9; + reg64_t reg_output = rsi; + reg64_t reg_bias = rbx; + + reg64_t kj = r10; + reg64_t oi_iter = r11; + reg64_t ki_iter = r12; + reg64_t reg_kh = abi_not_param1; + reg64_t simd_iter = r15; + reg64_t reg_oc_blocks = r14; + reg64_t imm_addr64 = reg_oc_blocks; + Xbyak::Reg32 reg_ci_flag = r13d; + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, + int oc_blocks); + inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); + inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); + inline void solve_common(int oc_blocks); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp new file mode 100644 index 0000000000..5f77d692f5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp @@ -0,0 +1,136 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "jit_sse42_convolution.hpp" +#include "mkldnn_thread.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +#define src_blk_off(f, n, c, h, w) \ + (pd()->ndims() == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w) + +#define wht_blk_off_(f, g, ...) \ + pd()->with_groups() \ + ? (f).blk_off(g, __VA_ARGS__) \ + : (f).blk_off(__VA_ARGS__) +#define wht_blk_off(f, g, oc, ic, kh, kw) \ + pd()->ndims() == 3 \ + ? wht_blk_off_(f, g, oc, ic, kw) \ + : wht_blk_off_(f, g, oc, ic, kh, kw) + +void jit_sse42_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); + const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh; + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + size_t n{0}, g{0}, ocbb{0}, oh{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + oh, jcp.oh); + for (size_t iwork = start; iwork < end; ++iwork) { + int ocb = ocbb * jcp.nb_oc_blocking; + int ocb_num = jcp.nb_oc_blocking; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + auto par_conv = jit_conv_call_s(); + + const int ij = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); + const int i_b_overflow = nstl::max(jcp.ih, ij + + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; + + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic + icb; + + const int ih = nstl::max(ij - jcp.t_pad + + div_up(i_t_overflow, + (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); + par_conv.src = &src[src_blk_off(src_d, n, + jcp.ic == 3 ? 0 : _ic, ih, 0)]; + + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)]; + + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, + jcp.ic == 3 ? 0 : icb, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = + &bias[bias_d.blk_off(_oc * jcp.oc_block)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_blocks = + nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; + + par_conv.kw_padding = 0; + const int kh_padding = jcp.kh + - div_up(i_t_overflow, (jcp.dilate_h + 1)) + - div_up(i_b_overflow, (jcp.dilate_h + 1)); + par_conv.kh_padding = nstl::max(0, kh_padding); + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + oh, jcp.oh); + } + icbb += icb_step; + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp new file mode 100644 index 0000000000..d2f0a38c5c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_SSE42_CONVOLUTION_HPP +#define CPU_JIT_SSE42_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_sse42_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", sse42, ""), + jit_sse42_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + return jit_sse42_conv_fwd_kernel_f32::init_conf(jcp_, *desc(), + *src_md(), *weights_md(), *dst_md(), *attr()); + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const bool flat = IC() == 3; + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + }; + + jit_sse42_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_sse42_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } + ~jit_sse42_convolution_fwd_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_sse42_conv_fwd_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp new file mode 100644 index 0000000000..0e734f7265 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp @@ -0,0 +1,1192 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "jit_generator.hpp" +#include "cpu_barrier.hpp" + +#include "jit_transpose_src_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +#define GET_OFF(x) offsetof(ctx_t, x) + +struct jit_trans_iw_ic_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_t) + + jit_trans_iw_ic_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + + enum { typesize = sizeof(float), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t k3333 = k1; + opmask_t k5555 = k2; + opmask_t kAAAA = k3; + opmask_t kCCCC = k4; + opmask_t k0F0F = k5; + opmask_t kF0F0 = k6; + opmask_t kTail = k7; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_iw_ic_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto pf_src_t0 = [=](int i) { + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_src, + (transpose_size + i) * src_stride)); + }; + + auto pf_tr_src_t0 = [=](int i) { + int offset = (transpose_size) * typesize + i * tr_src_stride; + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, offset)); + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, + offset + 64)); + }; + + auto pf_src_t1 = [=](int i) { + if(enable_prefetch) prefetcht1(EVEX_compress_addr(reg_src_prf, + i * src_stride)); + }; + + auto pf_tr_src_t1 = [=](int i) { + if(enable_prefetch) prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, + i * tr_src_stride)); + }; + + auto src_zmm = [=](int i) { + assert(i >= 0 && i < 16); + return Zmm(i); + }; + + auto tmp_zmm = [=](int i) { + assert(i >= 0 && i < 16); + return Zmm(16 + i); + }; + + auto load = [=](int i) { + vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto store = [=](Zmm r, int i) { + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto padding = [=] (Reg64 reg, int pad) { + kmovw(kTail, (1 << pad) - 1); + auto k = kTail; + auto base = reg; + base.setOpmaskIdx(k.getIdx(), true); + + auto zmm_zero = r; + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, zmm_zero); + }; + + mov(reg_tr_src_tmp, reg_tr_src); + if (l_pad > 0) + add(reg_tr_src_tmp, l_pad * typesize); + + if (tail != transpose_size) + kmovw(kTail, (1 << tail) - 1); + + // Xbyak does not allow k0 to be specified explicitly via the '|' + // operator, so we have to do this via a method call (implicitly + // EVEX encoding uses k0 to mean 'no mask') + bool partial_store = nrows < 16; + auto k = partial_store ? kTail : k0; + auto base = reg_tr_src_tmp; + base.setOpmaskIdx(k.getIdx(), true); + + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + if (nontemporal_stores && !partial_store) + vmovntps(addr, r); + else + vmovups(addr, r); + + if (r_pad > 0) { + add(reg_tr_src_tmp, tail * typesize); + padding(reg_tr_src_tmp, r_pad); + } + + if (l_pad > 0) { + padding(reg_tr_src, l_pad); + } + }; + + auto transpose16x8 = [=](int base_idx) { + assert(base_idx == 0 || base_idx == 8); + + // swap 1 + for (int i = 0; i < 4; i++) { + int src_idx0 = base_idx + i * 2; + int src_idx1 = src_idx0 + 1; + + int next_src_idx0 = src_idx0 + 2; + int next_src_idx1 = src_idx1 + 2; + bool load_next = base_idx == 0 || i < 3; + + if (base_idx == 0 && i == 0) { + load(src_idx0); + load(src_idx1); + } + + auto tmp0 = tmp_zmm(src_idx0); + auto tmp1 = tmp_zmm(src_idx1); + auto src0 = src_zmm(src_idx0); + auto src1 = src_zmm(src_idx1); + + if (next_src_idx0 < nrows && load_next) + load(next_src_idx0); + valignd(tmp0, src0, src0, 0x1); + pf_src_t1(base_idx + i); + + if (next_src_idx1 < nrows && load_next) + load(next_src_idx1); + valignd(tmp1, src1, src1, 0xf); + pf_src_t0(base_idx + i); + + vmovaps(src0 | kAAAA, tmp1); + vmovaps(src1 | k5555, tmp0); + } + // swap 2 + for (int i = 0; i < 4; i++) { + int select_half = (i < 2) ? 0 : 2; + int src_idx0 = base_idx + i + select_half + 0; + int src_idx2 = src_idx0 + 2; + + auto tmp0 = tmp_zmm(src_idx0); + auto tmp1 = tmp_zmm(src_idx2); + auto src0 = src_zmm(src_idx0); + auto src2 = src_zmm(src_idx2); + + valignd(tmp0, src0, src0, 0x2); + pf_src_t1(base_idx + 4 + i); + valignd(tmp1, src2, src2, 0xe); + pf_src_t0(base_idx + 4 + i); + vmovaps(src2 | k3333, tmp0); + vmovaps(src0 | kCCCC, tmp1); + } + + // swap 4 + for (int i = 0; i < 4; i++) { + int src_idx0 = base_idx + i; + int src_idx4 = src_idx0 + 4; + + auto tmp0 = tmp_zmm(src_idx0); + auto src0 = src_zmm(src_idx0); + auto src4 = src_zmm(src_idx4); + + vmovaps(tmp0, src0); + vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); + pf_tr_src_t1(base_idx / 2 + i); + vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); + pf_tr_src_t0(base_idx / 2 + i); + } + }; + + auto fixup16x16 = [=]() { + // swap 8 + for (int i = 0; i < 8; i++) { + auto tmp = tmp_zmm(i); + auto src0 = src_zmm(i); + auto src8 = src_zmm(8 + i); + vshuff64x2(tmp, src0, src8, 0x44); + store(tmp, i); + if (i % 2 == 0) { + pf_tr_src_t1(8 + i / 2); + pf_tr_src_t0(8 + i / 2); + } + } + + for (int i = 0; i < 8; i++) { + auto tmp = tmp_zmm(8 + i); + auto src0 = src_zmm(i); + auto src8 = src_zmm(8 + i); + vshuff64x2(tmp, src0, src8, 0xee); + store(tmp, 8 + i); + if (i % 2 == 0) { + pf_tr_src_t1(12 + i / 2); + pf_tr_src_t0(12 + i / 2); + } + } + }; + + transpose16x8(0); + transpose16x8(8); + fixup16x16(); +} + +void jit_trans_iw_ic_t::generate() { + preamble(); + + const int ic_block = conf_->ic_block; + const int iw = conf_->iw; + const int tr_iw = conf_->tr_iw; + const int transposes = utils::div_up(iw, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = iw - loop_iters * transpose_size; + + src_stride = ic_block * typesize; + assert(src_stride == 64); + tr_src_stride = tr_iw * typesize; + + bool nontemporal_stores = false; + enable_prefetch = iw > small_spatial ? 1 : 0; + + assert(transpose_size == ic_block); + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * typesize; + + const int left_pad = conf_->l_pad; + const int right_pad = tr_iw - iw - left_pad; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(k3333, 0x3333); // 0011001100110011 + kmovw(k5555, 0x5555); // 0101010101010101 + kmovw(kAAAA, 0xaaaa); // 1010101010101010 + kmovw(kCCCC, 0xcccc); // 1100110011001100 + kmovw(k0F0F, 0x0f0f); // 0000111100001111 + kmovw(kF0F0, 0xf0f0); // 1111000011110000 + + if (left_pad > 0 && loop_iters > 0) { + loop_iters--; + transpose(transpose_size, left_pad, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step + left_pad * typesize); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step + left_pad * typesize); + } + + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + if (transposes > 1) + transpose(tail, 0, right_pad, nontemporal_stores); + else + transpose(tail, left_pad, right_pad, nontemporal_stores); + + postamble(); +} + +struct jit_trans_iw_ic_int16_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t) + jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf): + jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + + enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t kFFFF = k1; + opmask_t k5555 = k2; + opmask_t kAAAA = k3; + opmask_t kAA = k4; + opmask_t k55 = k5; + opmask_t kCC = k6; + opmask_t k33 = k7; + opmask_t kTail = k1; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + reg64_t imm_addr64 = rbx; + + Xbyak::Zmm vidx1 = zmm31; + Xbyak::Zmm vidx2 = zmm30; + Xbyak::Zmm vidx3 = zmm29; + Xbyak::Zmm vidx4 = zmm28; + Xbyak::Zmm vidx5 = zmm27; + Xbyak::Zmm zmm_tmp = zmm26; + + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_iw_ic_int16_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto src_zmm = [=](int i) { + return Zmm(i); + }; + + auto src_ymm = [=](int i) { + assert(i >= 0 && i < 16); + return Ymm(i); + }; + + auto load_ymm = [=](int i) { + vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto store = [=](Zmm r, int i) { + + auto padding = [=] (Reg64 reg, int pad) { + kmovw(kTail, (1 << pad) - 1); + auto k = kTail; + auto base = reg; + base.setOpmaskIdx(k.getIdx(), true); + + auto zmm_zero = zmm_tmp; + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, zmm_zero); + }; + + int store_tail = (nrows%2) ? nrows+1 : nrows; + + int store_pad = (l_pad%2) ? l_pad/2 + 1 : l_pad/2; + mov(reg_tr_src_tmp, reg_tr_src); + if (l_pad > 0) { + padding(reg_tr_src, store_pad); + add(reg_tr_src_tmp, l_pad * typesize); + } + if (r_pad > 0) { + store_pad = (r_pad%2) ? r_pad/2 + 1 : r_pad/2; + int addr_shift = (r_pad%2) ? 1 : 0; + add(reg_tr_src_tmp, (nrows - addr_shift) * typesize); + padding(reg_tr_src_tmp, store_pad); + } + + mov(reg_tr_src_tmp, reg_tr_src); + add(reg_tr_src_tmp, l_pad * typesize); + + kmovw(kTail, (1 << store_tail/2) - 1); + auto k = kTail; + auto base = reg_tr_src_tmp; + base.setOpmaskIdx(k.getIdx(), true); + + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, r); + + }; + + kmovw(kFFFF, 0xffff); + //all loads + for (int i=0; i<16; i++){ + vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); + } + + for (int i = 0; i < nrows/2; i++) { + auto src0 = src_ymm(2*i); + auto src1 = src_ymm(2*i+1); + auto zmm_src0 = src_zmm(2*i); + load_ymm(2*i); + + vpunpcklwd(src1, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vpunpckhwd(src0, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vinserti64x4(zmm_src0, zmm_src0, src1, 1); + vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); + } + + // for odd numbers we need to mix row with zeroes + if (nrows%2) { + int i = nrows-1; + auto src0 = src_ymm(i); + auto src1 = src_ymm(i+1); //zero + + auto zmm_src0 = src_zmm(i); + vpxor(src1, src1, src1); + + load_ymm(i); + vpunpckhwd(src0, src0, src1); + vinserti64x4(zmm_tmp, zmm_tmp, src0, 0); + vpxor(src0, src0, src0); + load_ymm(i); + vpunpcklwd(src1, src0, src1); + vinserti64x4(zmm_tmp, zmm_tmp, src1, 1); + vpxord(zmm_src0, zmm_src0, zmm_src0); + vmovups(zmm_src0, zmm_tmp); + vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); + } + + // swap 1 + for (int i=0; i<4; i++) { + auto zmm0 = src_zmm(4*i); + auto zmm1 = src_zmm(4*i+2); + auto tmp0 = src_zmm(4*i+1); + auto tmp1 = src_zmm(4*i+3); + + vmovups(tmp0, zmm0); + vmovups(tmp1, zmm1); + + vpermps(tmp0 | kAAAA, vidx3, zmm1); + vpermps(tmp1 | k5555, vidx3, zmm0); + } + // swap 2 + int base_idx; + base_idx=0; + for (int i=0; i<2; i++) { + auto zmm0 = src_zmm(base_idx+2*i+1); + auto zmm1 = src_zmm(base_idx+2*i+5); + + auto tmp0 = src_zmm(base_idx+2*i); + auto tmp1 = src_zmm(base_idx+2*i+4); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kAA, vidx2, zmm1); + vpermpd(tmp1 | k55, vidx2, zmm0); + } + base_idx=8; + for (int i=0; i<2; i++) { + auto zmm0 = src_zmm(base_idx+2*i+1); + auto zmm1 = src_zmm(base_idx+2*i+5); + + auto tmp0 = src_zmm(base_idx+2*i); + auto tmp1 = src_zmm(base_idx+2*i+4); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kAA, vidx2, zmm1); + vpermpd(tmp1 | k55, vidx2, zmm0); + } + + // swap 3 + for (int i=0; i<4; i++) { + auto zmm0 = src_zmm(2*i); + auto zmm1 = src_zmm(2*i+8); + + auto tmp0 = src_zmm(2*i+1); + auto tmp1 = src_zmm(2*i+9); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kCC, vidx1, zmm1); + vpermpd(tmp1 | k33, vidx1, zmm0); + } + + // all stores + for (int i=0; i<8; i++) + vextracti64x4(src_ymm(2*i), src_zmm(2*i+1), 1); + + store(src_zmm(1), 0); + store(src_zmm(0), 1); + store(src_zmm(3), 2); + store(src_zmm(2), 3); + store(src_zmm(9), 4); + store(src_zmm(8), 5); + store(src_zmm(11), 6); + store(src_zmm(10), 7); + store(src_zmm(5), 8); + store(src_zmm(4), 9); + store(src_zmm(7), 10); + store(src_zmm(6), 11); + store(src_zmm(13), 12); + store(src_zmm(12), 13); + store(src_zmm(15), 14); + store(src_zmm(14), 15); + +} + +void jit_trans_iw_ic_int16_t::generate() { + preamble(); + + alignas(64) static constexpr const int64_t idx1[8] + = { 2, 3, 0, 1, 6, 7, 4, 5 }; + alignas(64) static constexpr const int64_t idx2[8] + = { 1, 0, 3, 2, 5, 4, 7, 6 }; + alignas(64) static constexpr const int32_t idx3[16] + = { 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 }; + alignas(64) static constexpr const int32_t idx4[16] + = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; + alignas(64) static constexpr const int32_t idx5[16] + = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; + + const int ic_block = conf_->ic_block; + const int iw = conf_->iw; + const int tr_iw = conf_->tr_iw; + const int transposes = utils::div_up(iw, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = iw - loop_iters * transpose_size; + + src_stride = ic_block * typesize; + tr_src_stride = tr_iw * typesize; + + bool nontemporal_stores = false; + enable_prefetch = iw > small_spatial ? 1 : 0; + + assert(transpose_size == ic_block); + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * typesize; + + const int left_pad = conf_->l_pad; + const int right_pad = tr_iw - iw - left_pad; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(kFFFF, 0xffff); + kmovw(k5555, 0x5555); + kmovw(kAAAA, 0xaaaa); + kmovw(kAA, 0xaa); + kmovw(k55, 0x55); + kmovw(kCC, 0xcc); + kmovw(k33, 0x33); + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa32(z, ptr[imm_addr64]); + }; + + vmovdqa64(vidx1, idx1); + vmovdqa64(vidx2, idx2); + vmovdqa32(vidx3, idx3); + vmovdqa32(vidx4, idx4); + vmovdqa32(vidx5, idx5); + + if (left_pad > 0 && loop_iters > 0) { + loop_iters--; + transpose(transpose_size, left_pad, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step + left_pad * typesize); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step + left_pad * typesize); + } + + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + if (transposes > 1) + transpose(tail, 0, right_pad, nontemporal_stores); + else + transpose(tail, left_pad, right_pad, nontemporal_stores); + + postamble(); + +} + +struct jit_trans_ow_oc_t: public jit_trans_dst_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t) + jit_trans_ow_oc_t(const jit_conv_conf_t *conf): jit_trans_dst_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + using zmm = const Xbyak::Zmm; + + enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t kFF = k1; + + zmm vidx1 = zmm31; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + reg64_t imm_addr64 = rbx; + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto src_zmm = [=](int i) { + return Zmm(i); + }; + + auto src_ymm = [=](int i) { + assert(i >= 0 && i < 16); + return Ymm(i); + }; + + auto load_ymm = [=](int i) { + vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + + auto store = [=](Zmm r, int i) { + auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride); + if (nontemporal_stores) + vmovntps(addr, r); + else + vmovups(addr, r); + }; + + for (int i = 0; i < nrows/2; i++) { + auto src0 = src_ymm(2*i); + auto src1 = src_ymm(2*i+1); + auto zmm_src0 = src_zmm(2*i); + load_ymm(2*i); + vpunpcklwd(src1, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vpunpckhwd(src0, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vinserti64x4(zmm_src0, zmm_src0, src1, 1); + vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); + store(zmm_src0, 2*i); + } + if (r_pad > 0) { + auto src0 = src_ymm(nrows-1); + auto src1 = src_ymm(nrows); + auto zmm_src0 = src_zmm(30); + load_ymm(nrows-1); + + vpxor(src1, src1, src1); + vpunpckhwd(src1, src0, src1); + vinserti64x4(zmm_src0, zmm_src0, src1, 0); + vpxor(src1, src1, src1); + vpunpcklwd(src0, src0, src1); + vinserti64x4(zmm_src0, zmm_src0, src0, 1); + vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); + store(zmm_src0, nrows-1); + } +} + +void jit_trans_ow_oc_t::generate() { + preamble(); + + alignas(64) static constexpr const int64_t idx1[8] + = { 4, 5, 0, 1, 6, 7, 2, 3 }; + + const int oc_block = conf_->oc_block; + const int ow = conf_->ow; + const int transposes = utils::div_up(ow, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = ow - loop_iters * transpose_size; + + src_stride = oc_block * typesize; + tr_src_stride = oc_block * typesize; + + bool nontemporal_stores = false; + enable_prefetch = ow > small_spatial ? 1 : 0; + + const int src_step = oc_block * transpose_size * typesize; + const int tr_src_step = oc_block * transpose_size * typesize; + const int right_pad = ow % 2; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(kFF, 0xFF); + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + vmovdqa64(vidx1, idx1); + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + transpose(tail, 0, right_pad, nontemporal_stores); + + postamble(); +} + +struct jit_trans_iw_x4_4x_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_x4_4x_t) + + jit_trans_iw_x4_4x_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + + void generate(); + enum { typesize = (int)sizeof(float) }; +}; + +/** @brief transposition of the form [:][iw/4][4] -> [:][4][iw/4] + * required for 1st 4fma backward by weights convolution */ +void jit_trans_iw_x4_4x_t::generate() { + using namespace utils; + + /* TODO: put into code */ + static int mask[16] = { + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, }; + + const auto &c = *conf_; + const int simd_w = cpu_isa_traits::vlen / typesize; + const int niters = c.tr_ld / simd_w; + + assert(niters <= 4); /* [bwd_w:tr_src:r1] */ + + Reg64 reg_ptr_src = r8; + Reg64 reg_ptr_tr_src = r9; + + Reg64 reg_ih = rax; + Reg64 reg_ih_end = rbx; + + Reg64 reg_nthr_oc_b = rsi; + Reg64 reg_ptr_tr_src_bctx = abi_not_param1; + + Reg64 reg_tmp = rdx; + + Zmm vmsk = Zmm(31); + Opmask kmsk = k7; + + auto emit_tr_sync = [&]() { + simple_barrier::generate(*this, reg_ptr_tr_src_bctx, reg_nthr_oc_b); + }; + + auto emit_tr_iw = [&]() { + auto vreg = [](int iter, int i) { + assert(4 * iter + i < 24); + return Zmm(4 * iter + i); + }; + auto vtmp = [](int i) { return Zmm(24 + i); }; + + auto emit_load = [&](int iter) { + for (int i = 0; i < 4; ++i) { + auto v = vreg(iter, i); + const int off = (iter * 4 + i) * simd_w; + + if (off + simd_w <= c.iw) + vmovups(v, ptr[reg_ptr_src + off * typesize]); + else if (off < c.iw) + vmovups(v | kmsk | T_z, ptr[reg_ptr_src + off * typesize]); + else + vpxord(v, v, v); + } + }; + + auto emit_tr = [&](int iter) { + for (int i = 0; i < 4; ++i) + vpermps(vreg(iter, i), vmsk, vreg(iter, i)); + + vshuff32x4(vtmp(0), vreg(iter, 0), vreg(iter, 1), 0x88); + vshuff32x4(vtmp(1), vreg(iter, 0), vreg(iter, 1), 0xdd); + vshuff32x4(vtmp(2), vreg(iter, 2), vreg(iter, 3), 0x88); + vshuff32x4(vtmp(3), vreg(iter, 2), vreg(iter, 3), 0xdd); + + vshuff32x4(vreg(iter, 0), vtmp(0), vtmp(2), 0x88); + vshuff32x4(vreg(iter, 2), vtmp(0), vtmp(2), 0xdd); + vshuff32x4(vreg(iter, 1), vtmp(1), vtmp(3), 0x88); + vshuff32x4(vreg(iter, 3), vtmp(1), vtmp(3), 0xdd); + }; + + auto emit_store = [&]() { + for (int i = 0; i < 4; ++i) { + for (int iter = 0; iter < niters; ++iter) { + const size_t off = i * c.tr_ld + iter * simd_w; + vmovups(ptr[reg_ptr_tr_src + off * typesize], vreg(iter, i)); + } + } + }; + + for (int iter = 0; iter < niters; ++iter) + emit_load(iter); + + for (int iter = 0; iter < niters; ++iter) + emit_tr(iter); + + emit_store(); + }; + + preamble(); + + mov(reg_ptr_src, ptr[abi_param1 + GET_OFF(src)]); + mov(reg_ptr_tr_src, ptr[abi_param1 + GET_OFF(tr_src)]); + + mov(reg_nthr_oc_b.cvt32(), ptr[abi_param1 + GET_OFF(nthr_oc_b)]); + mov(reg_ih.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_start)]); + mov(reg_ih_end.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_end)]); + mov(reg_ptr_tr_src_bctx, ptr[abi_param1 + GET_OFF(tr_src_bctx)]); + + emit_tr_sync(); + + Label l_ih_loop, l_tr_done; + cmp(reg_ih, reg_ih_end); + je(l_tr_done, T_NEAR); + + mov(reg_tmp, (size_t)&mask[0]); + vmovups(vmsk, ptr[reg_tmp]); + + if (c.iw % simd_w) { + const char load_mask = (1 << (c.iw % simd_w)) - 1; + mov(reg_tmp, load_mask); + kmovw(kmsk, reg_tmp.cvt32()); + } + + /* src += ih_start * c.iw; */ + imul(reg_tmp, reg_ih, c.iw * typesize); + add(reg_ptr_src, reg_tmp); + /* tr_src += ih_start * c.stride_w * c.tr_ld; */ + imul(reg_tmp, reg_ih, c.stride_w * c.tr_ld * typesize); + add(reg_ptr_tr_src, reg_tmp); + + L(l_ih_loop); { + emit_tr_iw(); + + add(reg_ptr_src, c.iw * typesize); + add(reg_ptr_tr_src, c.stride_w * c.tr_ld * typesize); + + inc(reg_ih); + cmp(reg_ih, reg_ih_end); + jl(l_ih_loop, T_NEAR); + } + + L(l_tr_done); + + emit_tr_sync(); + + postamble(); +} + +/* +// ------------------------------------------------- +// jit_transpose4x16_src +// ------------------------------------------------- +*/ + +void jit_transpose4x16_src::transpose(int nrows) +{ + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 4, "Unsupported transpose size"); + if (!nrows) + return; + + auto pf_src_t0 = [=](int i) { + if (tparams->src_pf0_distance) + prefetcht0(EVEX_compress_addr( + reg_src, (tparams->src_pf0_distance + i) * src_stride)); + }; + + auto pf_tr_src_t0 = [=](int i) { + if (tparams->tr_src_pf0_distance) + prefetcht0(EVEX_compress_addr(reg_tr_src, + (tparams->tr_src_pf0_distance + i) * src_stride)); + }; + + auto pf_src_t1 = [=](int i) { + if (tparams->src_pf1) + prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride)); + }; + + auto pf_tr_src_t1 = [=](int i) { + if (tparams->tr_src_pf1) + prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride)); + }; + + auto src_zmm = [=](int i) { + assert(i >= 0 && i < 4); + return Zmm(i); + }; + + auto tmp_zmm = [=](int i) { + assert(i >= 0 && i < 4); + return Zmm(4 + i); + }; + + auto load = [=](int i) { + vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto store = [=](Zmm r, int i) { + vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r); + }; + + auto tmp0 = tmp_zmm(0); + auto tmp1 = tmp_zmm(1); + auto tmp2 = tmp_zmm(2); + auto tmp3 = tmp_zmm(3); + + auto src0 = src_zmm(0); + auto src1 = src_zmm(1); + auto src2 = src_zmm(2); + auto src3 = src_zmm(3); + for (int i = 0; i < nrows; i++) { + load(i); + } + + for (size_t i = nrows; i < 4; i++) { + vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); + } + + vmovupd(tmp0, src0); + vmovupd(tmp1, src1); + pf_src_t0(0); + vpermpd(tmp0 | kF0, vidx01, src2); + vpermpd(tmp1 | kF0, vidx01, src3); + + valignd(src0, src0, src0, 8); + valignd(src1, src1, src1, 8); + pf_src_t0(1); + vmovupd(tmp2, src0); + vmovupd(tmp3, src1); + pf_src_t0(2); + vpermpd(tmp2 | kF0, vidx10, src2); + vpermpd(tmp3 | kF0, vidx10, src3); + pf_src_t0(3); + + vmovupd(src0, tmp0); + pf_src_t1(0); + vmovupd(src1, tmp2); + pf_src_t1(1); + vmovupd(src2, tmp1); + pf_src_t1(2); + vmovupd(src3, tmp3); + pf_src_t1(3); + vpermpd(src0 | kCC, vidx1, tmp1); + vpermpd(src1 | kCC, vidx1, tmp3); + pf_tr_src_t0(0); + vpermpd(src2 | k33, vidx1, tmp0); + vpermpd(src3 | k33, vidx1, tmp2); + pf_tr_src_t0(1); + + vmovupd(tmp0, src0); + vmovupd(tmp1, src2); + pf_tr_src_t0(2); + vmovupd(tmp2, src1); + vmovupd(tmp3, src3); + pf_tr_src_t0(3); + vpermps(tmp0 | kFFFF, vidxP, src0); + pf_tr_src_t1(0); + vpermps(tmp1 | kFFFF, vidxP, src2); + pf_tr_src_t1(1); + vpermps(tmp2 | kFFFF, vidxP, src1); + pf_tr_src_t1(3); + vpermps(tmp3 | kFFFF, vidxP, src3); + pf_tr_src_t1(4); + + store(tmp0, 0); + store(tmp1, 1); + store(tmp2, 2); + store(tmp3, 3); +} + +alignas(64) static constexpr const int64_t idx01[8] + = { 0, 0, 0, 0, 0, 1, 2, 3 }; +alignas(64) static constexpr const int64_t idx10[8] + = { 0, 0, 0, 0, 4, 5, 6, 7 }; +alignas(64) static constexpr const int64_t idx1[8] = { 2, 3, 0, 1, 6, 7, 4, 5 }; +alignas(64) static constexpr const int32_t idxP[16] + = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + +void jit_transpose4x16_src::generate() +{ + preamble(); + + const int ic_block = params->ic_block; + const int is = params->is; + int tail = is % transpose_size; + + src_stride = ic_block * typesize; + assert(src_stride == 64); + tr_src_stride = ic_block * typesize; + + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * transpose_size * typesize; + +#define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x) + mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]); + mov(reg_src, ptr[param1 + GET_TR_OFF(src)]); + mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]); + mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]); +#undef GET_TR_OFF + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { + mov(imm_addr64, reinterpret_cast(addr)); + jit_generator::vmovdqa32(z, ptr[imm_addr64]); + }; + + kmovw(kF0, 0xf0); // 11110000 + kmovw(kCC, 0xcc); // 11001100 + kmovw(k33, 0x33); // 00110011 + kmovw(kFFFF, 0xffff); // 1111111111111111 + + vmovdqa64(vidx01, idx01); + vmovdqa64(vidx10, idx10); + vmovdqa64(vidx1, idx1); + vmovdqa32(vidxP, idxP); + + Label loop_label; + Label tail_label; + + cmp(reg_loop, transpose_size); + jl(tail_label, T_NEAR); + + L(loop_label); + { + transpose(transpose_size); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, transpose_size); + cmp(reg_loop, transpose_size); + jge(loop_label, T_NEAR); + } + L(tail_label); + transpose(tail); + + postamble(); +} + +jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) { + if (conf->ver == ver_4fma && !conf->is_1stconv) + return new jit_trans_iw_ic_t(conf); + if (conf->ver == ver_4fma && conf->is_1stconv) + return new jit_trans_iw_x4_4x_t(conf); + assert(!"unsupported configuration"); + return nullptr; +} + +jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) { + assert(!"unsupported configuration"); + return nullptr; +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp new file mode 100644 index 0000000000..565e97e4fc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp @@ -0,0 +1,145 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_TRANSPOSE_SRC_HPP +#define CPU_JIT_TRANSPOSE_SRC_HPP + +#include "cpu_barrier.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_trans_src_t { + struct ctx_t { + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; + + /* 1st conv 4fma: backward by weights */ + int nthr_oc_b; /* number of threads process given src image */ + int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ + simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ + }; + + jit_trans_src_t(const jit_conv_conf_t *conf) + : conf_(conf), ker_(nullptr) {} + virtual ~jit_trans_src_t() {} + + void operator()(const ctx_t *ctx) + { assert(ker_); ker_(ctx); } + + const jit_conv_conf_t *conf_; + void (*ker_)(const ctx_t *); +}; + +struct jit_src_transpose_s { + size_t size; + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; +}; + +struct jit_trans_dst_t { + struct ctx_t { + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; + + /* 1st conv 4fma: backward by weights */ + int nthr_oc_b; /* number of threads process given src image */ + int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ + simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ + }; + + jit_trans_dst_t(const jit_conv_conf_t *conf) + : conf_(conf), ker_(nullptr) {} + virtual ~jit_trans_dst_t() {} + + void operator()(const ctx_t *ctx) + { assert(ker_); ker_(ctx); } + + const jit_conv_conf_t *conf_; + void (*ker_)(const ctx_t *); +}; + +struct jit_transpose4x16_src_t { + int src_pf0_distance; + int tr_src_pf0_distance; + bool src_pf1; + bool tr_src_pf1; +}; + +struct jit_transpose4x16_src : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src) + + jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams, + jit_transpose4x16_src_t *tparams_) + : params(aparams), tparams(tparams_) + { + this->generate(); + jit_ker = (decltype(jit_ker))this->getCode(); + } + + const jit_1x1_conv_conf_t *params; + const jit_transpose4x16_src_t *tparams; + void (*jit_ker)(jit_src_transpose_s *); + + void operator()(jit_src_transpose_s *arg) { jit_ker(arg); } + + static const int transpose_size = 4; +private: + static const int typesize = sizeof(float); + + int src_stride, tr_src_stride; + + Xbyak::Reg64 imm_addr64 = rbx; + + Xbyak::Opmask kF0 = k1; + Xbyak::Opmask kCC = k2; + Xbyak::Opmask k33 = k3; + Xbyak::Opmask kFFFF = k4; + + Xbyak::Zmm vidx01 = zmm31; + Xbyak::Zmm vidx10 = zmm30; + Xbyak::Zmm vidx1 = zmm29; + Xbyak::Zmm vidxP = zmm28; + + Xbyak::Reg64 reg_src = r8; + Xbyak::Reg64 reg_tr_src = r9; + Xbyak::Reg64 reg_src_prf = r10; + Xbyak::Reg64 reg_tr_src_prf = r11; + Xbyak::Reg64 reg_loop = r12; + Xbyak::Reg64 reg_tr_src_tmp = r13; + Xbyak::Reg32 regw_tmp = r14d; + + void transpose_block(int ur, int nrows); + void transpose(int nrows); + void generate(); +}; + +jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf); +jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp new file mode 100644 index 0000000000..53313f9f01 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp @@ -0,0 +1,327 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_1x1_CONV_UTILS_HPP +#define JIT_UNI_1x1_CONV_UTILS_HPP + +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; + +struct reduce_to_unit_stride_t { + convolution_desc_t conv_d_; + bool reduce_src_; + size_t space_per_thread_; +}; + +/* 1x1-kernel does not support non-unit strides so far, so the idea is: + * - for fwd or bwd_weights: to copy src to a scratch memory (with strides + * equal to 1) and then call the kernel + * - for bwd_data: reduce the problem to the one with unit stride by + * performing computations in a scratch memory (with strides equal to 1) + * and then copy the result to diff_src */ +template +inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d, + const memory_desc_t *&src_d, const memory_desc_t *dst_d) { + const bool is_bwd_data = self->desc()->prop_kind + == prop_kind::backward_data; + + const int ndims = src_d->ndims; + const auto dat_tag = ndims == 3 + ? memory_desc_wrapper(dst_d).matches_one_of_tag( + format_tag::nCw8c, format_tag::nCw16c) + : memory_desc_wrapper(dst_d).matches_one_of_tag( + format_tag::nChw8c, format_tag::nChw16c); + + bool rtus_applicable = true + && utils::pick(ndims - 3, + (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type, + data_type::s32)), + (conv_d->strides[0] != 1 || conv_d->strides[1] != 1)) + && dat_tag != format_tag::undef; + for (int d = 2; d < ndims; ++d) { + /* TODO: relax these conditions (by improving reducer) */ + rtus_applicable = rtus_applicable + && conv_d->padding[0][d - 2] == 0 + && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d]; + } + + if (rtus_applicable) { + self->rtus_.reduce_src_ = true; + conv_d = &(self->rtus_.conv_d_ = *conv_d); + self->rtus_.conv_d_.strides[0] = 1; + if (ndims == 4) + self->rtus_.conv_d_.strides[1] = 1; + utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2); + if (ndims == 4) + utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2); + const int ic = src_d->dims[1]; + if (is_bwd_data) { + src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d); + self->rtus_.conv_d_.diff_src_desc.dims[1] = ic; + memory_desc_wrapper::compute_blocking( + self->rtus_.conv_d_.diff_src_desc, dat_tag); + } else { + data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type; + src_d = &(self->rtus_.conv_d_.src_desc = *dst_d); + self->rtus_.conv_d_.src_desc.dims[1] = ic; + self->rtus_.conv_d_.src_desc.data_type = data_type; + memory_desc_wrapper::compute_blocking( + self->rtus_.conv_d_.src_desc, dat_tag); + } + } +} + +template +inline void rtus_prepare_space_info(conv_pd_t *self, + memory_tracking::registrar_t &scratchpad) { + const auto &jcp = self->jcp_; + + const int max_threads = mkldnn_get_max_threads(); + const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind, + jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking); + size_t typesize = types::data_type_size( + conv_prop_invariant_src_d(self->desc())->data_type); + + self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block; + scratchpad.book(memory_tracking::names::key_conv_rtus_space, + typesize * max_threads * self->rtus_.space_per_thread_); +} + +template +struct rtus_driver_t: public jit_generator { + + struct call_params_t { + const void *ws; /* reduced image (w/ strides = 1) */ + const void *src; /* source image (w/ non-unit strides) */ + size_t icb; + size_t os; + size_t iw_start; + }; + + void (*ker_)(const call_params_t *p); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + + Xbyak::Reg64 reg_ws = abi_param1; + Xbyak::Reg64 reg_src = abi_not_param1; + Xbyak::Reg64 reg_icb = rdx; + Xbyak::Reg64 reg_os = r11; + Xbyak::Reg64 reg_iw_start = r8; + + Xbyak::Reg64 reg_cur_os = rax; + Xbyak::Reg64 reg_cur_iw = r9; + Xbyak::Reg64 reg_cur_src = r10; + + int iw_, stride_w_; + int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_; + bool src_to_ws_; + size_t typesize_; + Vmm reg_zero; + Vmm reg_v; + + rtus_driver_t(int iw, int stride_w, int src_step_h, + int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize) + : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h) + , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb) + , src_to_ws_(src_to_ws), typesize_(typesize) + { + using namespace Xbyak; + vlen_ = cpu_isa_traits::vlen; + vlen_shift_ = cpu_isa_traits::vlen_shift; + if (typesize_ == 2) { + vlen_ /= 2; + vlen_shift_--; + } + + reg_zero = Vmm(0); + reg_v = Vmm(1); + + generate(); + } + + void loop_is() { + using namespace Xbyak; + + mov(reg_cur_src, reg_src); + mov(reg_cur_iw, reg_iw_start); + mov(reg_cur_os, reg_os); + + Label is_loop, skip_h_step; + L(is_loop); + + if (src_to_ws_) { + vmovups(reg_v, ptr[reg_cur_src]); + vmovups(ptr[reg_ws], reg_v); + } else { + vmovups(reg_v, ptr[reg_ws]); + vmovups(ptr[reg_cur_src], reg_v); + for (int w = 1; w < stride_w_; ++w) + vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); + } + + add(reg_ws, vlen_); + + add(reg_cur_iw, stride_w_); + add(reg_cur_src, stride_w_ * vlen_); + + cmp(reg_cur_iw, iw_); + jl(skip_h_step); + /* for 1d convolution the loop over h should be skipped */ + if (src_step_icb_ == iw_) jmp(skip_h_step); + + if (src_to_ws_) { + add(reg_cur_src, (src_step_h_ - iw_) * vlen_); + } else { + Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */ + mov(reg_cur_src_fin, reg_cur_src); + add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_); + Label ih_loop; + L(ih_loop); + + for (int w = 0; w < stride_w_; ++w) + vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); + + add(reg_cur_src, stride_w_ * vlen_); + cmp(reg_cur_src, reg_cur_src_fin); + jl(ih_loop); + } + xor_(reg_cur_iw, reg_cur_iw); + + L(skip_h_step); + + sub(reg_cur_os, vlen_); + jnz(is_loop); + + /* restore dst */ + sub(reg_ws, reg_os); + } + + void generate() { + using namespace Xbyak; + assert(isa == avx2 || isa == avx512_common + || isa == avx512_core || isa == avx512_mic); + +#if defined(_WIN32) + assert(reg_src == abi_not_param1 && abi_not_param1 == rdi); + push(rdi); +#endif + +#define READ_PARAM(what) \ + mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)]) + READ_PARAM(src); + READ_PARAM(icb); + READ_PARAM(os); + READ_PARAM(iw_start); + + assert(reg_ws == abi_param1); + READ_PARAM(ws); /* reg_ws should always be read the last */ +#undef READ_PARAM + + shl(reg_os, vlen_shift_); + + if (!src_to_ws_) + uni_vpxor(reg_zero, reg_zero, reg_zero); + + Label icb_loop; + L(icb_loop); + + loop_is(); + + add(reg_ws, ws_step_icb_ * vlen_); + add(reg_src, src_step_icb_ * vlen_); + + dec(reg_icb); + jnz(icb_loop, T_NEAR); + +#if defined(_WIN32) + pop(rdi); +#endif + + uni_vzeroupper(); + ret(); + this->ker_ = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +template +inline void init_rtus_driver(conv_t *self) { + const auto &conf = *self->pd(); + if (!conf.rtus_.reduce_src_) return; + + const auto &cd = *conf.desc(); + const int ndims = conf.ndims(); + const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0]; + const int stride_w = cd.strides[ndims - 3]; + + const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data; + const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md(); + + const int ih = ndims == 3 ? 1 : src_d.dims[2]; + const int iw = src_d.dims[ndims - 1]; + + const int src_step_h = stride_h * iw; + const int src_step_icb = ih * iw; + const int ws_step_icb = conf.jcp_.is; + const bool src_to_ws = !is_bwd_data; + const size_t typesize = types::data_type_size( + conv_prop_invariant_src_d(self->pd()->desc())->data_type); + + self->rtus_driver_ = new rtus_driver_t(iw, stride_w, src_step_h, + src_step_icb, ws_step_icb, src_to_ws, typesize); +} + +inline int best_divider(int value, int min_divider, int max_divider, + bool find_max, int step = 1) +{ + max_divider = nstl::max(1, nstl::min(max_divider, value)); + min_divider = nstl::max(1, nstl::min(min_divider, max_divider)); + + auto loss_ratio = [](int total, int chunk) + { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); }; + + float min_loss = FLT_MAX; + int x_divider = max_divider; + for (int divider = max_divider; divider >= min_divider; divider -= step) { + const float loss = loss_ratio(value, divider); + if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) { + min_loss = loss; + x_divider = divider; + } + } + return x_divider; +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp new file mode 100644 index 0000000000..72fe3a8109 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp @@ -0,0 +1,1407 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "jit_uni_batch_normalization.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace memory_tracking::names; + +using namespace Xbyak; +namespace barrier = simple_barrier; + +typedef float data_t; + +template +struct jit_bnorm_t: public jit_generator { + struct call_params_t { + // keep all sizes at 8 bytes -- jit code expects this + size_t N_ithr, N_nthr; + size_t coff_max, soff_max; + size_t mb_stride_Bc, spat_size, spat_size_loc; + size_t S_s, S_tail; + size_t is_cblk_tail; + data_t chan_size, eps, one; + const data_t *scale_shift; + const data_t *mean, *var; + const data_t *diff_scale_shift; + const data_t *src, *dst; + const data_t *diff_src, *diff_dst; + const data_t *rbuf1, *rbuf2; + const uint8_t *ws; + barrier::ctx_t *barrier; + }; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional3::type; + const AddressFrame &vmmword = (isa == sse42) ? xword : + (isa == avx2) ? yword : zword; + + const int vlen = isa == sse42 ? 32 : cpu_isa_traits::vlen; + + const batch_normalization_pd_t *bdesc_; + bool is_spatial_thr_; + + void (*ker)(const call_params_t *); + void operator()(const call_params_t *p) { (*ker)(p); } + + Reg64 reg_param = abi_param1; + + Reg64 reg_scale_shift = rbx; + Reg64 reg_rbuf1 = abi_not_param1; + Reg64 reg_rbuf2 = rdx; + + Reg64 reg_mean = rbp; + Reg64 reg_var = reg_param; + Reg64 reg_diff_scale_shift = rax; + + Reg64 reg_coff = r8; + Reg64 reg_coff_max = r9; + Reg64 reg_soff = r10; + Reg64 reg_soff_max = r11; + Reg64 reg_ctr = r12; + Reg64 reg_roff = r13; + + Reg64 reg_mb_stride_Bc = r14; + + Reg64 reg_src = r15; + Reg64 reg_diff_src = reg_rbuf1; + Reg64 reg_dst = rsi; + Reg64 reg_diff_dst = reg_dst; + + Reg64 reg_tmp_off = reg_roff; + + // Reuse loop counters + Reg64 reg_bar = reg_coff; + Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff + Reg64 reg_tmp = reg_ctr; + + // Relu section + bool with_relu, with_relu_inf_only; + Vmm vzero; // is_fwd() ? vdiff_beta : vbeta + Reg64 reg_ws = reg_roff; + Label l_relu_mask_avx2; + Opmask kstore_mask = Opmask(1); + + // channel tail processing + Opmask ktail_mask = Opmask(2); + + size_t unroll_blocks; + size_t unroll_regs; + Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5); + Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6); + Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7); + Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8); + Vmm vone = Vmm(isa == avx512_common ? 24 : 9); + Vmm vmean = Vmm(isa == avx512_common ? 25 : 10); + Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11); + Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12); + Vmm veps = Vmm(isa == avx512_common ? 28 : 13); + Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14); + Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15); + + size_t t0_pf_offt; + size_t t1_pf_offt; + size_t spat_size; + size_t chan_data_offt; + + enum { + stack_off_N_nthr = 0, + stack_off_N_ithr = 8, + stack_off_src = 16, + stack_off_dst = 24, + stack_off_diff_src = 32, + stack_off_diff_dst = 40, + stack_off_diff_scale_shift = 48, + stack_off_ws = 56, + stack_off_barrier = 64, + stack_off_spat_size_loc = 72, + stack_off_s_s = 80, + stack_off_s_tail = 88, + stack_off_is_cblk_tail = 96, + stack_size_required = 104, + }; + + bool is_c_padded() const { + const memory_desc_wrapper data_d(bdesc_->src_md()); + return bdesc_->C() != data_d.padded_dims()[1]; + } + + void compute_static_strides() { + spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H(); + chan_data_offt = bdesc_->C() * sizeof(data_t); + + if (isa == avx512_mic) { + t0_pf_offt = 4096; + t1_pf_offt = 0; + } else { + t0_pf_offt = 0; + t1_pf_offt = 0; + } + } + + void load_common_params() { +# define PARAM_OFF(x) offsetof(call_params_t, x) + mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]); + if (bdesc_->is_bwd()) + mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]); + mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]); + mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]); + mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]); + shl(reg_coff_max, 2); + shl(reg_soff_max, 2); + shl(reg_mb_stride_Bc, 2); + + mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); + mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]); + + uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]); + uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]); + uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]); + + mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]); + mov(ptr[rsp + stack_off_N_nthr], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]); + mov(ptr[rsp + stack_off_N_ithr], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]); + mov(ptr[rsp + stack_off_src], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]); + mov(ptr[rsp + stack_off_dst], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]); + mov(ptr[rsp + stack_off_diff_src], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]); + mov(ptr[rsp + stack_off_diff_dst], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]); + mov(ptr[rsp + stack_off_ws], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]); + mov(ptr[rsp + stack_off_barrier], reg_tmp); + if (is_spatial_thr_) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]); + mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]); + mov(ptr[rsp + stack_off_s_s], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]); + mov(ptr[rsp + stack_off_s_tail], reg_tmp); + } + if (is_c_padded()) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]); + mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp); + } + + if (bdesc_->is_fwd()) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); + mov(reg_var, reg_tmp); + } else { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]); + mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); + mov(reg_var, reg_tmp); + } +# undef PARAM_OFF + } + + void prepare_tail_mask_avx512_common() { + if (!is_c_padded()) return; + + const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); + const int mask = (1 << tail) - 1; + + Reg32 regw_tmp = reg_tmp.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + void prepare_tail_mask_avx2_common() { + if (!is_c_padded()) return; + + const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); + static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0, 0, 0, 0, 0, 0, 0, 0}; + + mov(reg_tmp, reinterpret_cast(&mask[8 - tail])); + vmovups(vtail_mask, ptr[reg_tmp]); + } + + void prepare_relu() { + with_relu = bdesc_->is_fwd() + ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu() + : bdesc_->fuse_bn_relu(); + with_relu_inf_only = with_relu && bdesc_->is_fwd() + && !(bdesc_->fuse_bn_relu() && bdesc_->is_training()); + + vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta; + if (with_relu) { + uni_vpxor(vzero, vzero, vzero); + if (!bdesc_->is_fwd() && isa == avx2) + prepare_l_relu_mask_avx2(); + } + } + + void prepare_l_relu_mask_avx2() { + Label l_mask_after; + jmp(l_mask_after); + align(32); + L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */ + for (int i = 0; i < 8; ++i) dd(1< + void spat_loop(size_t len, size_t blocks, size_t regs, + init_t init, body_t body, fini_t fini) { + size_t factor = regs * blocks; + size_t loop_unroll = len / factor * factor; + size_t loop_tail = len - loop_unroll; + size_t num_active_regs = (len < regs) ? len : regs; + for (size_t i = 0; i < num_active_regs; i++) + init(i); + if (loop_unroll) { + if (is_spatial_thr_) { + mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); + add(reg_soff, ptr[rsp + stack_off_s_s]); + } else { + mov(reg_ctr, loop_unroll); + } + Label label; + L(label); { + for (size_t i = 0; i < factor; i++) { + size_t base_reg = i % regs; + body(base_reg, i); + } + add(reg_soff, factor * vlen); + sub(reg_ctr, factor); + jnz(label); + } + if (is_spatial_thr_) { + add(reg_soff, ptr[rsp + stack_off_s_tail]); + } + } + + for (size_t i = 0; i < loop_tail; i++) { + size_t base_reg = i % regs; + body(base_reg, i); + } + if (loop_tail) + add(reg_soff, loop_tail * vlen); + + for (size_t i = 0; i < num_active_regs; i++) + fini(i); + } + + void mean_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + spat_loop(spat_size, unroll_blocks, + unroll_regs, + [=](size_t base_reg) { + Vmm v = Vmm(base_reg * 2); + if (base_reg) + uni_vpxor(v, v, v); + }, + [=](size_t base_reg, size_t i) { + Vmm v0 = Vmm(base_reg * 2 + 0); + Vmm v1 = Vmm(base_reg * 2 + 1); + size_t offt = i * vlen; + uni_vmovups(v1, + vmmword[reg_src + reg_soff + offt]); + uni_vaddps(v0, v0, v1); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b = Vmm(0); + Vmm v = Vmm(base_reg * 2); + if (base_reg) + uni_vaddps(b, b, v); + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void var_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + spat_loop(spat_size, unroll_blocks, unroll_regs, + [=](size_t base_reg) { + Vmm v = Vmm(base_reg * 3); + if (base_reg > 0) + uni_vpxor(v, v, v); + }, + [=](size_t base_reg, size_t i) { + Vmm v = Vmm(3 * base_reg); + Vmm vtmp0 = Vmm(3 * base_reg + 1); + Vmm vtmp1 = Vmm(3 * base_reg + 2); + size_t offt = i * vlen; + uni_vmovups(vtmp0, + vmmword[reg_src + reg_soff + offt]); + if (isa == sse42) { + movups(vtmp1, vmean); + subps(vtmp1, vtmp0); + } else { + vsubps(vtmp1, vmean, vtmp0); + } + uni_vfmadd231ps(v, vtmp1, vtmp1); + + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b = Vmm(0); + Vmm v = Vmm(base_reg * 3); + if (base_reg) + uni_vaddps(b, b, v); + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void compute_mean_variance() { + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + xor_(reg_coff, reg_coff); + Label zero_rbuf; + L(zero_rbuf); { + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(zero_rbuf); + } + + mov(reg_src, ptr[rsp + stack_off_src]); + + xor_(reg_soff, reg_soff); + Label mean_spatial; + L(mean_spatial); { + xor_(reg_coff, reg_coff); + + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + mean_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + + mean_channels(); + + sub(reg_src, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(mean_spatial); + } + + Label no_mean_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + jne(no_mean_reduction); + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + Label mean_reduction_channels; + L(mean_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + mov(reg_ctr, reg_nnthr); + Label mean_reduction_thrs; + L(mean_reduction_thrs); { + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); + uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0)); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(mean_reduction_thrs); + } + uni_vdivps(Vmm(1), Vmm(1), vchan_size); + uni_vmovups_maybe_tail(mean_ptr(), Vmm(1)); + + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + + cmp(reg_coff, reg_coff_max); + jne(mean_reduction_channels); + } + } + L(no_mean_reduction); + barrier(); + + xor_(reg_soff, reg_soff); + Label var_spatial; + L(var_spatial); { + xor_(reg_coff, reg_coff); + + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + var_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + + var_channels(); + + sub(reg_src, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(var_spatial); + } + + Label no_var_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + jne(no_var_reduction); + + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + Label var_reduction_channels; + L(var_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + mov(reg_ctr, reg_nnthr); + Label var_reduction_thrs; + L(var_reduction_thrs); { // TODO: unroll (?) + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(var_reduction_thrs); + } + uni_vdivps(Vmm(1), Vmm(1), vchan_size); + uni_vmovups_maybe_tail(var_ptr(), Vmm(1)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + + cmp(reg_coff, reg_coff_max); + jne(var_reduction_channels); + } + } + L(no_var_reduction); + barrier(); + } + + void forward_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + + if (bdesc_->use_scaleshift()) { + uni_vmovups_maybe_tail(vgamma, gamma_ptr()); + uni_vmovups_maybe_tail(vbeta, beta_ptr()); + } + + Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone; + Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar; + + if (isa == sse42) { + movups(vbuf, vscale); + divps(vbuf, vsqrtvar); + movups(vdiv, vbuf); + } else { + vdivps(vdiv, vscale, vsqrtvar); + } + + auto compute = [=](bool output_is_aligned) { + spat_loop(spat_size, unroll_blocks, unroll_regs, + [](size_t base_reg) {UNUSED(base_reg);}, + [=](size_t base_reg, size_t i) { + Vmm v = Vmm(base_reg); + size_t offt = i * vlen; + uni_vmovups(v, + vmmword[reg_src + reg_soff + offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + uni_vsubps(v, v, vmean); + if (bdesc_->use_scaleshift()) { + uni_vfmadd213ps(v, vgamma, vbeta); + } else { + uni_vmulps(v, v, vsqrtvar); + } + if (with_relu_inf_only) { + uni_vmaxps(v, v, vzero); + } else if (with_relu) { + if (isa == avx512_common) + fwd_process_relu_avx512_common(v, offt); + else + fwd_process_relu_avx2(v, offt, Vmm(3)); + } + if (output_is_aligned) { + uni_vmovntps( + vmmword[reg_dst + reg_soff + offt], v); + } else { + uni_vmovups( + vmmword[reg_dst + reg_soff + offt], v); + } + }, + [](size_t base_reg) {UNUSED(base_reg);}); + }; + + Label unaligned_store, end_store; + test(reg_dst, vlen - 1); + jnz(unaligned_store, T_NEAR); + compute(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + compute(false); + } + L(end_store); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void forward() { + mov(reg_src, ptr[rsp + stack_off_src]); + mov(reg_dst, ptr[rsp + stack_off_dst]); + mov(reg_ws, ptr[rsp + stack_off_ws]); + + xor_(reg_soff, reg_soff); + Label dst_spatial; + L(dst_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + forward_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + add(reg_dst, vlen / 2); + mov(reg_coff, vlen / 2); + + forward_channels(); + + sub(reg_src, vlen / 2); + sub(reg_dst, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jnz(dst_spatial); + } + } + + void backward_sh_channels() { + Label sh_channels; + L(sh_channels); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]); + spat_loop(spat_size, 1, 1, + [=](size_t base_reg) { + if (base_reg > 0) { + for (int i = 0; i < 2; i++) { + Vmm v(base_reg * 5 + i); + uni_vpxor(v, v, v); + } + } + }, + [=](size_t base_reg, size_t i) { + Vmm o0 = Vmm(base_reg * 5 + 0); + Vmm o1 = Vmm(base_reg * 5 + 1); + Vmm t1 = Vmm(base_reg * 5 + 2); + Vmm t2 = Vmm(base_reg * 5 + 3); + Vmm t3 = Vmm(base_reg * 5 + 4); + size_t offt = i * vlen; + uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]); + uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff + + offt]); + if (with_relu) { + if (isa == avx512_common) + bwd_process_relu_avx512_common(t2, offt); + else if (isa == avx2) + bwd_process_relu_avx2(t2, offt, t3); + else + assert(false); + } + uni_vsubps(t3, vmean, t1, t3); + if (isa == sse42) { + mulps(t3, t2); + subps(o0, t3); + } else { + vfnmadd231ps(o0, t3, t2); + } + uni_vaddps(o1, o1, t2); + mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt + + t1_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b0 = Vmm(0); + Vmm b1 = Vmm(1); + if (base_reg) { + uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0)); + uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1)); + } + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1)); + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(sh_channels); + } + } + + void backward_diff_channels() { + Label diff_channels; + L(diff_channels); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); + if (bdesc_->use_scaleshift()) + uni_vmovups_maybe_tail(vgamma, gamma_ptr()); + uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr()); + uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr()); + uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar); + uni_vdivps(vdiff_beta, vdiff_beta, vchan_size); + uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size); + + auto compute = [=](bool output_is_aligned) { + spat_loop(spat_size, unroll_blocks, unroll_regs, + [=](size_t base_reg) {UNUSED(base_reg);}, + [=](size_t base_reg, size_t i) { + Vmm v(base_reg * 2 + 0); + Vmm t(base_reg * 2 + 1); + Vmm t1(base_reg * 2 + 2); + size_t offt = i * vlen; + uni_vmovups(v, vmmword[reg_diff_dst + reg_soff + + offt]); + if (with_relu) { + if (isa == avx512_common) + bwd_process_relu_avx512_common(v, offt); + else if (isa == avx2) + bwd_process_relu_avx2(v, offt, t); + else + assert(false); + } + if (!bdesc_->use_global_stats()) { + uni_vsubps(v, v, vdiff_beta); + uni_vmovups(t, vmmword[reg_src + reg_soff + + offt]); + uni_vsubps(t, vmean, t, t1); + uni_vmulps(t, t, vdiff_gamma); + uni_vaddps(v, v, t); + } + uni_vmulps(v, v, vsqrtvar); + if (bdesc_->use_scaleshift()) { + uni_vmulps(v, v, vgamma); + } + if (output_is_aligned) { + uni_vmovntps( + vmmword[reg_diff_src + reg_soff + offt], + v); + } else { + uni_vmovups( + vmmword[reg_diff_src + reg_soff + offt], + v); + } + mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_diff_dst + reg_soff + + offt + t1_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) {UNUSED(base_reg);}); + }; + + Label unaligned_store, end_store; + test(reg_diff_src, vlen - 1); + jnz(unaligned_store, T_NEAR); + compute(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + compute(false); + } + L(end_store); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(diff_channels); + } + } + + void backward() { + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + xor_(reg_coff, reg_coff); + Label zero_rbuf, sh_spatial; + + L(zero_rbuf); { + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(zero_rbuf); + } + + mov(reg_src, ptr[rsp + stack_off_src]); + mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]); + if (with_relu) { + assert(isa == avx2 || isa == avx512_common); + mov(reg_ws, ptr[rsp + stack_off_ws]); + } + + xor_(reg_soff, reg_soff); + L(sh_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) { + mov(reg_tmp_off, reg_soff); + } + backward_sh_channels(); + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_diff_dst, vlen / 2); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + backward_sh_channels(); + sub(reg_diff_dst, vlen / 2); + sub(reg_src, vlen / 2); + } + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(sh_spatial); + } + + mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]); + + Label no_sh_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + Label sh_reduction_channels; + jne(no_sh_reduction, T_NEAR); + + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + L(sh_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); + mov(reg_ctr, reg_nnthr); + Label sh_reduction_thrs; + L(sh_reduction_thrs); { // TODO: unroll (?) + uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]); + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(sh_reduction_thrs); + } + uni_vmulps(Vmm(0), Vmm(0), vsqrtvar); + uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0)); + uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(sh_reduction_channels); + } + } + L(no_sh_reduction); + barrier(); + + mov(reg_diff_src, ptr[rsp + stack_off_diff_src]); + if (with_relu) { + assert(isa == avx2 || isa == avx512_common); + mov(reg_ws, ptr[rsp + stack_off_ws]); + } + + xor_(reg_soff, reg_soff); + Label diff_spatial; + L(diff_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) { + mov(reg_tmp_off, reg_soff); + } + backward_diff_channels(); + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_diff_dst, vlen / 2); + add(reg_diff_src, vlen / 2); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + backward_diff_channels(); + sub(reg_diff_dst, vlen / 2); + sub(reg_diff_src, vlen / 2); + sub(reg_src, vlen / 2); + } + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(diff_spatial); + } + } + + jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) { + static_assert(isa == sse42 || isa == avx2 || isa == avx512_common + || isa == avx512_mic, "unsupported isa"); + + const int simd_w = isa == sse42 ? 8 : + cpu_isa_traits::vlen / sizeof(data_t); + is_spatial_thr_ = + bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t)); + + unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1; + unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1; + + preamble(); + + if (isa == avx512_common) + prepare_tail_mask_avx512_common(); + else if (isa == avx2) + prepare_tail_mask_avx2_common(); + + compute_static_strides(); + sub(rsp, stack_size_required); + load_common_params(); + prepare_relu(); + + if (bdesc_->is_fwd()) { + if (!bdesc_->stats_is_src()) { + compute_mean_variance(); + } + forward(); + } else { + backward(); + } + add(rsp, stack_size_required); + postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); + } +}; + +template +struct uni_bnorm_driver_t: public c_compatible { + uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc) + : bdesc_(bdesc), ker_(bdesc_) + { + const int nthrs = mkldnn_get_max_threads(); + const dim_t C_PADDED = get_c_padded(bdesc_); + + size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED + * bdesc_->D() * bdesc_->H() * bdesc_->W(); + l3_size_ = get_cache_size(3, true) * nthrs / 2; + do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0); + } + + ~uni_bnorm_driver_t() {} + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const batch_normalization_pd_t *bdesc) { + int nthrs = mkldnn_get_max_threads(); + dim_t C_PADDED = get_c_padded(bdesc); + + int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED; + int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED; + int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs; + + scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz); + scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz); + scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz); + + if (mkldnn_thr_syncable()) { + int n_barriers = C_PADDED / simd_w; + scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers); + } + } + + void exec(int ithr, int nthr, const data_t *src, data_t *diff_src, + data_t *dst, const data_t *diff_dst, const data_t *scale_shift, + data_t *diff_scale_shift, const data_t *mean, const data_t *var, + const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) { + auto sbuf = scratchpad.get(key_bnorm_tmp_stats); + auto pbuf = scratchpad.get(key_bnorm_tmp_diff_ss); + auto rbuf = scratchpad.get(key_bnorm_reduction); + auto barriers = scratchpad.get(key_barrier); + + dim_t N = bdesc_->MB(); + dim_t C = bdesc_->C(); + dim_t C_PADDED = get_c_padded(bdesc_); + dim_t D = bdesc_->D(); + dim_t H = bdesc_->H(); + dim_t W = bdesc_->W(); + dim_t SP = D * H * W; + dim_t img_size = C_PADDED * D * H * W; + const int vlen = isa == sse42 ? 32 : cpu_isa_traits::vlen; + + typename jit_bnorm_t::call_params_t p; + + p.eps = bdesc_->desc()->batch_norm_epsilon; + p.one = 1.0f; + p.spat_size = D * H * W; + p.chan_size = 1.0f * N * p.spat_size; + + dim_t C_blks = C_PADDED / simd_w; + + int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0}; + dim_t C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0}; + + dim_t C_blks_per_iter{ 1 }; + int64_t iters{ 1 }; + if (do_blocking_) { + int num_tensors = bdesc_->is_fwd() ? 1 : 2; + size_t working_set_size + = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors; + bnorm_utils::cache_balance(working_set_size, C_blks, + C_blks_per_iter, iters); + } + + bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, + true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks, + SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, + S_ithr, S_nthr, S_s, S_e); + + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1)); + + p.N_ithr = SP_N_ithr; + p.N_nthr = SP_N_nthr; + + int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter; + int global_C_blk_s; + int global_barriers_per_iter = C_nthr; + + for (int64_t it = 0; it < iters; it++) { + if (it == iters - 1 && iters > 1) { + C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + + // Update call parameters for JIT, last iteration + p.N_ithr = N_ithr * S_nthr + S_ithr; + p.N_nthr = N_nthr * S_nthr; + } + + global_C_blk_s = do_blocking_ ? + (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s : + C_blk_s; + + int C_blks_thr = C_blk_e - C_blk_s; + int N_thr = N_e - N_s; + + size_t coff_base = global_C_blk_s * simd_w; + size_t soff_base + = global_C_blk_s * p.spat_size * simd_w + N_s * img_size; + + p.spat_size_loc = S_e - S_s; + p.S_s = S_s * vlen; + p.S_tail = (p.spat_size - S_e) * vlen; + p.coff_max = C_blks_thr * simd_w; + p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base; + p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base; + p.scale_shift = scale_shift + coff_base; + p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_) + ? pbuf : diff_scale_shift) + coff_base; + + p.soff_max = N_thr * img_size; + p.src = src + soff_base; + p.dst = dst + soff_base; + p.diff_src = diff_src + soff_base; + p.diff_dst = diff_dst + soff_base; + p.ws = ws + soff_base / 8; + + p.mb_stride_Bc = img_size - p.coff_max * p.spat_size; + + // use SP_N_nthr which is the same as p.N_nthr except maybe for + // the last iteration. + p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr + + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w; + // rbuf1 and rbuf2 have to be disjoint + p.rbuf2 = p.rbuf1 + C_PADDED * nthr; + p.is_cblk_tail = (it * C_blks_per_iter + C_blk_e) * simd_w > C; + + size_t iter_bariers + = do_blocking_ ? it * global_barriers_per_iter : 0; + p.barrier = barriers + C_ithr + iter_bariers; + if (p.soff_max != 0 && p.coff_max != 0) + ker_(&p); + } + } + + void init_barriers(const memory_tracking::grantor_t &scratchpad) { + auto barriers = scratchpad.get(key_barrier); + if (barriers) { + const int n_barriers = get_c_padded(bdesc_) / simd_w; + for (int i = 0; i < n_barriers; ++i) + barrier::ctx_init(&barriers[i]); + } + } + +private: + enum { + simd_w = isa == sse42 ? 8 : cpu_isa_traits::vlen / sizeof(data_t) + }; + + static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) { + return true + && !bdesc->stats_is_src() + && bdesc->desc()->prop_kind == prop_kind::forward_inference; + } + + static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc) + { + return false + || (bdesc->is_bwd() && !bdesc->use_scaleshift()) + || bdesc->desc()->prop_kind == prop_kind::backward_data; + } + + static dim_t get_c_padded(const batch_normalization_pd_t *bdesc) + { return bdesc->src_md()->padded_dims[1]; } + + const batch_normalization_pd_t *bdesc_; + bool do_blocking_; + size_t l3_size_; + + jit_bnorm_t ker_; +}; + +} + +using namespace data_type; +using namespace format_tag; +using namespace utils; + +/* fwd */ + +template +status_t jit_uni_batch_normalization_fwd_t::pd_t::init() { + auto desired_fmt_tag = (ndims() == 4) + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + + bool ok = true + && mayiuse(isa) + && is_fwd() + && !has_zero_dim_memory() + && one_of(ndims(), 4, 5) + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) { + if (isa < avx2) return status::unimplemented; + init_default_ws(1); + } + + if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() + && isa < avx2) + return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + uni_bnorm_driver_t::init_scratchpad(scratchpad, this); + + return status::success; +} + +template +jit_uni_batch_normalization_fwd_t::jit_uni_batch_normalization_fwd_t( + const pd_t *apd): cpu_primitive_t(apd) +{ bnorm_driver_ = new uni_bnorm_driver_t(pd()); } + +template +status_t jit_uni_batch_normalization_fwd_t::execute( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + auto mean = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)) + : CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + auto var = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)) + : CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto scratchpad = this->scratchpad(ctx); + + bnorm_driver_->init_barriers(scratchpad); + + parallel(0, [&](const int ithr, const int nthr) { + bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr, + scale_shift, nullptr, mean, var, ws, scratchpad); + }); + + return status::success; +} + +template +jit_uni_batch_normalization_fwd_t::~jit_uni_batch_normalization_fwd_t() +{ delete bnorm_driver_; } + +/* bwd */ + +template +status_t jit_uni_batch_normalization_bwd_t::pd_t::init() { + auto desired_fmt_tag = (ndims() == 4) + ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c + : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c; + + bool ok = true + && mayiuse(isa) + && is_bwd() + && !has_zero_dim_memory() + && one_of(ndims(), 4, 5) + && everyone_is(f32, src_md()->data_type, diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() + && isa < avx2) + return status::unimplemented; + + if (fuse_bn_relu()) { + if (isa < avx2) return status::unimplemented; + init_default_ws(1); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + /* TODO: extra checks required */ + + auto scratchpad = scratchpad_registry().registrar(); + uni_bnorm_driver_t::init_scratchpad(scratchpad, this); + + return status::success; +} + +template +jit_uni_batch_normalization_bwd_t::jit_uni_batch_normalization_bwd_t( + const pd_t *apd): cpu_primitive_t(apd) +{ bnorm_driver_ = new uni_bnorm_driver_t(pd()); } + +template +status_t jit_uni_batch_normalization_bwd_t::execute( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto var = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scale_shift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + + bnorm_driver_->init_barriers(scratchpad); + + parallel(0, [&](const int ithr, const int nthr) { + bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst, + scale_shift, diff_scale_shift, mean, var, ws, scratchpad); + }); + + return status::success; +} + +template +jit_uni_batch_normalization_bwd_t::~jit_uni_batch_normalization_bwd_t() +{ delete bnorm_driver_; } + +/* struct instantiation */ +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; +template struct jit_uni_batch_normalization_fwd_t; +template struct jit_uni_batch_normalization_bwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp new file mode 100644 index 0000000000..96410ec84e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp @@ -0,0 +1,100 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_BATCH_NORMALIZATION_HPP +#define JIT_UNI_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_isa_traits.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { template struct uni_bnorm_driver_t; } + +template +struct jit_uni_batch_normalization_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_batch_normalization_fwd_t); + + status_t init(); + }; + + typedef typename prec_traits::type data_t; + + jit_uni_batch_normalization_fwd_t(const pd_t *apd); + ~jit_uni_batch_normalization_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + uni_bnorm_driver_t *bnorm_driver_; +}; + +template +struct jit_uni_batch_normalization_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_batch_normalization_bwd_t); + + status_t init(); + }; + + typedef typename prec_traits::type data_t; + + jit_uni_batch_normalization_bwd_t(const pd_t *apd); + ~jit_uni_batch_normalization_bwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + uni_bnorm_driver_t *bnorm_driver_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp new file mode 100644 index 0000000000..b7dc5f85c5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp @@ -0,0 +1,1302 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_uni_dw_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +template +void jit_uni_dw_conv_fwd_kernel_f32::load_src(int ur_ch_blocks, int ur_w) { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int ow = 0; ow < ur_w; ow++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + int b_off = ch*jcp.ch_block + i*4; + if (this->jcp.with_bias) + uni_vmovups(vmm_acc, + vmmword[reg_bias + b_off*sizeof(float)]); + else + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + + int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block + + ow*jcp.ch_block + i*4; + if (this->jcp.with_sum) + uni_vaddps(vmm_acc, vmm_acc, + vmmword[reg_output + o_off*sizeof(float)]); + } + } + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_filter( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(iter_kw, reg_kw); + mov(aux1_reg_input, aux_reg_input); + mov(aux1_reg_kernel, aux_reg_kernel); + + Label kw_label; + L(kw_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*jcp.ih*jcp.iw*ch_blk + + ow*stride_w*ch_blk + i*4; + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_input + + inp_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + add(aux1_reg_kernel, ch_blk*sizeof(float)); + add(aux1_reg_input, ch_blk*dilate_w*sizeof(float)); + + dec(iter_kw); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_filter_unrolled( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int kw = 0; kw < jcp.kw; kw++) { + int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4; + + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*jcp.ih*jcp.iw*ch_blk + + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux_reg_input + + inp_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + } + + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::apply_activation( + int ur_ch_blocks, int ur_w) { + if (this->jcp.with_eltwise) { + int repeats = isa == sse42 ? 2 : 1; + eltwise_injector_->compute_vector_range(4, repeats * ur_w * ur_ch_blocks + 4); + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::store_dst( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int ow = 0; ow < ur_w; ow++) { + int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4; + Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst); + } + } + } +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::loop_body(int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w); + apply_filter_unrolled(ur_ch_blocks, ur_w); + apply_activation(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w); + + add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + apply_activation(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w); + + add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::generate() { + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +template +bool jit_uni_dw_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +template +status_t jit_uni_dw_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(isa)) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1]; + + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + bool ok_to_pad_channels = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.oc / jcp.ch_block; + jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + return status::success; +} + +template +void jit_uni_dw_conv_fwd_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc_without_padding != jcp.oc) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +template struct jit_uni_dw_conv_fwd_kernel_f32; +template struct jit_uni_dw_conv_fwd_kernel_f32; +template struct jit_uni_dw_conv_fwd_kernel_f32; + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::load_ddst( + int ur_ch_blocks, int ur_str_w) { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::apply_filter( + int ur_ch_blocks, int ur_str_w) { + int kw = jcp.kw; + int kh = jcp.kh; + int ow = jcp.ow; + int oh = jcp.oh; + + int ch_blk = jcp.ch_block; + int stride_h = jcp.stride_h; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(aux1_reg_ddst, aux_reg_ddst); + mov(aux1_reg_kernel, aux_reg_kernel); + + mov(iter_kw, reg_kw); + Label kw_label; + L(kw_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*kh*kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int w = 0; w < ur_str_w; w++) { + int ddst_off = (ch*oh*ow + w)*ch_blk + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_ddst + + ddst_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + + add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float)); + sub(aux1_reg_ddst, ch_blk*sizeof(float)); + + sub(iter_kw, stride_w); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + + add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float)); + sub(aux_reg_ddst, ow*ch_blk*sizeof(float)); + + sub(iter_kh, stride_h); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::store_dsrc( + int ur_ch_blocks, int ur_str_w) { + int ch_blk = jcp.ch_block; + int iw = jcp.iw; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4; + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + + uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_data_kernel_f32::loop_body( + int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_str_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_str_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template +void jit_uni_dw_conv_bwd_data_kernel_f32::generate() { + preamble(); + + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + + this->postamble(); +} + +template +status_t jit_uni_dw_conv_bwd_data_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) { + if (!mayiuse(isa)) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1]; + + jcp.ih = diff_src_d.dims()[2]; + jcp.iw = diff_src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + bool ok_to_pad_channels = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1 + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.ic / jcp.ch_block; + jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + return status::success; +} + +template +void jit_uni_dw_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +template struct jit_uni_dw_conv_bwd_data_kernel_f32; +template struct jit_uni_dw_conv_bwd_data_kernel_f32; +template struct jit_uni_dw_conv_bwd_data_kernel_f32; + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::zero_filter() { + for (int r = 0; r < reg_repeats; ++r) { + for (int i = 0; i < jcp.kw; ++i) { + Vmm vmm_acc = get_acc_reg(r * jcp.kw + i); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::load_filter() { + for (int r = 0; r < reg_repeats; ++r) { + const int reg_set = r * jcp.kw; + for (int i = 0; i < jcp.kw; ++i) { + int off_filter = (reg_set + i) * simd_w; + Vmm vmm_acc = get_acc_reg(reg_set + i); + uni_vmovups(vmm_acc, + vmmword[reg_tmp_filter + off_filter * sizeof(float)]); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::zero_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vpxor(vmm_bias, vmm_bias, vmm_bias); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::load_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vmovups( + vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_ow_step_unroll( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const int iw_block = ow_block * jcp.stride_w; + const int right_border = jcp.iw - iw_block; + + const int cascade_input = nstl::min(jcp.stride_w, jcp.kw); + + /* preamble count for number of cascaded LOAD + FMA operation */ + const int input_overlap = nstl::max(jcp.kw - l_pad, 0); + + /* LOAD initial input registers, then cascade LOADs and FMAs*/ + for (int r = 0; r < reg_repeats; ++r) { + for (int i_ur = 0; i_ur < unroll_w; ++i_ur) { + int off_output = (i_ur * reg_repeats + r) * simd_w; + Vmm vmm_output = get_output_reg(r); + uni_vmovups(vmm_output, + ptr[reg_tmp_output + off_output * sizeof(float)]); + if (i_ur == 0) { + for (int c = 0; c < input_overlap; ++c) { + int off_input + = ((c - pad_offset) * reg_repeats + r) * simd_w; + Vmm vmm_input + = get_input_reg((c % jcp.kw) * reg_repeats + r); + uni_vmovups(vmm_input, + ptr[reg_tmp_input + off_input * sizeof(float)]); + } + } else { + for (int c = 0; c < cascade_input; ++c) { + int overlap = (i_ur - 1) * jcp.stride_w + input_overlap; + int off_input + = ((overlap + c - pad_offset) * reg_repeats + r) + * simd_w; + Vmm vmm_input = get_input_reg( + ((overlap + c) % jcp.kw) * reg_repeats + r); + uni_vmovups(vmm_input, + ptr[reg_tmp_input + off_input * sizeof(float)]); + } + } + + for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) { + int io_overlap = i_kw + (i_ur * jcp.stride_w); + + /* Don't apply FMAs that fall into the padded region */ + if (io_overlap - l_pad < 0 + || io_overlap - jcp.l_pad >= right_border) + continue; + + Vmm vmm_input = get_input_reg( + ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r); + Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r); + Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input; + if (isa == sse42) + uni_vmovups(vmm_aux, vmm_input); + uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output); + } + } + } +} + +template +inline void +jit_uni_dw_conv_bwd_weights_kernel_f32::compute_bias_step_unroll( + const int unroll_w) { + for (int r = 0; r < reg_repeats; ++r) { + for (int i = 0; i < unroll_w; ++i) { + Vmm vmm_bias = get_bias_reg(r); + int off_output = (i * reg_repeats + r) * simd_w; + if (isa == sse42) { + /* Need to support unaligned address loads for SSE42*/ + Vmm vmm_output = get_output_reg(1 + r); + uni_vmovups(vmm_output, + ptr[reg_tmp_output + off_output * sizeof(float)]); + uni_vaddps(vmm_bias, vmm_bias, vmm_output); + } else { + uni_vaddps(vmm_bias, vmm_bias, + vmmword[reg_tmp_output + off_output * sizeof(float)]); + } + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::store_filter() { + for (int r = 0; r < reg_repeats; ++r) { + const int reg_set = r * jcp.kw; + for (int i = 0; i < jcp.kw; ++i) { + int off_filter = (i + reg_set) * simd_w; + Vmm vmm_acc = get_acc_reg(i + reg_set); + uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)], + vmm_acc); + } + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::store_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vmovups( + vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_bias_loop( + const int block_size) { + Label oh_label; + Label ow_blk_label; + + const int unroll_w = nstl::min(block_size, jcp.ow); + const int unroll_w_trips = jcp.ow / unroll_w; + const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0; + + const int ch_offset = jcp.ch_block; + + mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); + mov(reg_oh_worksize, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); + + mov(reg_tmp_output, reg_output_baddr); + L(oh_label); + { + + mov(iter_ow_blk, unroll_w_trips); + L(ow_blk_label); + { + + compute_bias_step_unroll(unroll_w); + add(reg_tmp_output, unroll_w * ch_offset * sizeof(float)); + + dec(iter_ow_blk); + cmp(iter_ow_blk, 0); + jg(ow_blk_label, T_NEAR); + } + + if (tail_w > 0) { + compute_bias_step_unroll(tail_w); + add(reg_tmp_output, tail_w * ch_offset * sizeof(float)); + } + + inc(reg_oh); + cmp(reg_oh, reg_oh_worksize); + jl(oh_label, T_NEAR); + } +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_zero_filter() { + + const int ch_offset = jcp.ch_block; + + Label kh_loop_label, skip_zeroing_label; + + mov(reg_exec_flags, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); + and_(reg_exec_flags, FLAG_ZERO_FILTER); + test(reg_exec_flags, reg_exec_flags); + je(skip_zeroing_label); + + zero_filter(); + + mov(reg_tmp_filter, reg_filter_baddr); + mov(reg_kh, jcp.kh); + L(kh_loop_label); + { + store_filter(); + + add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_loop_label); + } + + /* Comeback pointers */ + sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float)); + + L(skip_zeroing_label); +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_h_step( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const int ch_offset = jcp.ch_block; + + Label kh_loop_label, skip_loop_label; + + cmp(reg_kh_count, 0); + je(skip_loop_label, T_NEAR); + + mov(reg_kh, reg_kh_count); + L(kh_loop_label); + { + load_filter(); + compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block); + store_filter(); + + add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_loop_label); + } + + /* Comeback pointers */ + Label kh_comeback_label; + mov(reg_kh, reg_kh_count); + L(kh_comeback_label); + { + sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); + sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_comeback_label, T_NEAR); + } + + L(skip_loop_label); +} + +template +inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_h_loop( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ? + jcp.ih / jcp.stride_h - 1 : + jcp.oh - jcp.b_pad - 1; + const int ch_offset = jcp.ch_block; + const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; + const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; + + Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label, + end_h_loop_label; + + mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); + mov(reg_oh_worksize, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); + mov(reg_kh_count, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]); + + mov(reg_tmp_output, reg_output_baddr); + mov(reg_tmp_input, reg_input_baddr); + mov(reg_tmp_filter, reg_filter_baddr); + + L(h_loop_label); + { + + compute_h_step(unroll_w, l_pad, pad_offset, ow_block); + + add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float)); + + /* If within the top_pad region */ + if (jcp.t_pad > 0) { + /* Skip t_pad area if no longer in initial h_block */ + cmp(reg_oh, jcp.t_pad); + jg(skip_tpad_label, T_NEAR); + + cmp(reg_kh_count, jcp.kh); + jge(skip_tpad_label, T_NEAR); + + add(reg_kh_count, t_overlap_off); + sub(reg_tmp_filter, + t_overlap_off * jcp.kw * ch_offset * sizeof(float)); + + /* kernel has moved beyond padding (adjust for stride effects) */ + if (jcp.t_pad % jcp.stride_h != 0) { + int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h; + add(reg_tmp_input, + inp_corr * jcp.iw * ch_offset * sizeof(float)); + } + jmp(tpad_loop_label, T_NEAR); + } + + L(skip_tpad_label); + + cmp(reg_oh, io_overlap); + jl(skip_bpad_label, T_NEAR); + sub(reg_kh_count, b_overlap_off); + + L(skip_bpad_label); + add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float)); + + L(tpad_loop_label); + + cmp(reg_oh, jcp.ih / jcp.stride_h); + jge(end_h_loop_label, T_NEAR); + + inc(reg_oh); + + cmp(reg_oh, reg_oh_worksize); + jl(h_loop_label, T_NEAR); + } + L(end_h_loop_label); +} + +template +inline void +jit_uni_dw_conv_bwd_weights_kernel_f32::compute_ow_block_unroll() { + + const int ch_offset = jcp.ch_block; + int ow = jcp.ow; + int pad_offset = 0; + int l_pad = jcp.l_pad; + + /* Calculate effective padding */ + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + + /* Is this strictly defined by: + * -code-size (?) + * -address size (?) */ + const int max_unroll_w = 30; + const int block_size = 15; + + int unroll_w_tail = 0; + int unroll_w = 0; + int unroll_w_trips = 0; + + if (jcp.ow > max_unroll_w) { + unroll_w = nstl::min(block_size, jcp.ow); + unroll_w_trips = ow / unroll_w; + /* calculate tail */ + unroll_w_tail = ow % unroll_w; + /* Perform some rebalancing if tail too small*/ + if ((unroll_w_tail == 0 && r_pad != 0) + || (r_pad > 0 && r_pad >= unroll_w_tail)) { + if (unroll_w_trips > 1) { + unroll_w_tail += unroll_w; + unroll_w_trips--; + } else { + /* Idealy, this case shouldn't happen */ + unroll_w_tail += (unroll_w - unroll_w / 2); + unroll_w = unroll_w / 2; + } + } + } else { + unroll_w = jcp.ow; + unroll_w_trips = nstl::max(1, ow / unroll_w); + } + if (jcp.with_bias) { + Label skip_load_bias; + mov(reg_bias_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]); + + zero_bias(); + + mov(reg_exec_flags, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); + and_(reg_exec_flags, FLAG_ZERO_BIAS); + test(reg_exec_flags, reg_exec_flags); + jne(skip_load_bias); + + load_bias(); + + L(skip_load_bias); + compute_bias_loop(block_size); + + store_bias(); + } + + /* Pass filter address, then offset for h_padding. */ + compute_zero_filter(); + mov(reg_kh_offset, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]); + add(reg_filter_baddr, reg_kh_offset); + + /* compute left padded block */ + if (l_pad) { + compute_h_loop(unroll_w, l_pad, 0, 0); + add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); + add(reg_input_baddr, + unroll_w * jcp.stride_w * ch_offset * sizeof(float)); + unroll_w_trips--; + pad_offset = l_pad; + l_pad = 0; + } + + /* compute middle block */ + Label ow_blk_label; + + /* Insert loop for 'ow' block when middle block needs to execute more + * than once */ + bool do_ow_blk_loop = unroll_w_trips > 1; + if (do_ow_blk_loop) { + mov(iter_ow_blk, unroll_w_trips); + L(ow_blk_label); + } + if (unroll_w_trips > 0) { + compute_h_loop(unroll_w, l_pad, pad_offset, 0); + add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); + add(reg_input_baddr, + unroll_w * jcp.stride_w * ch_offset * sizeof(float)); + } + if (do_ow_blk_loop) { + dec(iter_ow_blk); + cmp(iter_ow_blk, 0); + jg(ow_blk_label, T_NEAR); + } + + /* compute right padded block */ + if (unroll_w_tail) { + compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail); + } +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::generate() { + preamble(); + + mov(reg_input_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]); + mov(reg_output_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]); + mov(reg_filter_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]); + + compute_ow_block_unroll(); + + this->postamble(); +} + +template +status_t jit_uni_dw_conv_bwd_weights_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d, int nthreads) { + if (!mayiuse(isa)) + return status::unimplemented; + + jcp.ngroups = diff_weights_d.dims()[0]; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + + jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic); + + if (!jcp.is_depthwise) + return status::unimplemented; + + jcp.ch_block = isa == avx512_common ? 16 : 8; + + jcp.mb = src_d.dims()[0]; + + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = diff_weights_d.dims()[3]; + jcp.kw = diff_weights_d.dims()[4]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.kw <= 3 + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; + if (!args_ok) + return status::unimplemented; + + jcp.nb_ch = jcp.ngroups / jcp.ch_block; + + /* kernel applicability check wrt boundaries + * the conditions are quite general across the kernels we have, + * but ideally the check should belong to a specific kernel... */ + const int max_hpad = (jcp.kh - 1 + 1) / 2; + const int max_wpad = (jcp.kw - 1 + 1) / 2; + const bool boundaries_ok = true && jcp.t_pad <= max_hpad + && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad + && jcp.r_pad <= max_wpad; + if (!boundaries_ok) + return status::unimplemented; + + balance(jcp, nthreads); + + return status::success; +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + /* Notes: if splitting thread work on 'mb', then a reduction has to take + * place. Hence, book a per-thread, local weights-buffer for the + * reduction */ + if (jcp.nthr_mb > 1) { + const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; + scratchpad.book(key_conv_wei_reduction, + sizeof(float) * wei_size * (jcp.nthr_mb - 1)); + + if (jcp.with_bias) + scratchpad.book(key_conv_bia_reduction, + sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1)); + } +} + +template +void jit_uni_dw_conv_bwd_weights_kernel_f32::balance(jit_conv_conf_t &jcp, + int nthreads) { + jcp.nthr = nthreads; + jcp.nthr_g = jcp.nthr_mb = 1; + + /* Basic-Heuristics for parallel strategy: + * 1) Tries to parallel on the number of Groups (g) where tasks are + * independent. Otherwise, + * 2) Tries to split the work across g and MiniBatch (mb). + * Parallelizing on mb requires computing a reduction for weights. + * + * NOTE: because of 'task partitioning' scheme, there will be unbalanced + * per-thread load when the number of threads is high (e.g. > 16). + */ + jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr); + jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb); + + jcp.nthr = jcp.nthr_g * jcp.nthr_mb; +} + +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; +template struct jit_uni_dw_conv_bwd_weights_kernel_f32; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp new file mode 100644 index 0000000000..9c08fc4a09 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp @@ -0,0 +1,253 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP +#define JIT_UNI_DW_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32) + + jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp) + : jcp(ajcp), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_uni_dw_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + const Xbyak::AddressFrame &vmmword = (isa == sse42) + ? xword : (isa == avx2) ? yword : zword; + const int vlen = cpu_isa_traits::vlen; + + // dw convolution + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t aux1_reg_input = r10; + reg64_t reg_kernel = r11; + reg64_t aux_reg_kernel = r12; + reg64_t aux1_reg_kernel = r13; + reg64_t reg_output = r14; + reg64_t reg_bias = r15; + reg64_t reg_kh = rax; + reg64_t reg_kw = rbx; + reg64_t iter_kh = rdx; + reg64_t iter_kw = rsi; + reg64_t reg_ur_w = rbp; + reg64_t reg_ch_blocks = aux1_reg_input; + reg64_t imm_addr64 = aux1_reg_input; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + inline void load_src(int ur_ch_blocks, int ur_w); + inline void apply_filter(int ur_ch_blocks, int ur_w); + inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w); + inline void apply_activation(int ur_ch_blocks, int ur_w); + inline void store_dst(int ur_ch_blocks, int ur_w); + inline void loop_body(int ur_ch_blocks); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; + + void generate(); +}; + +template +struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32) + + jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t aux1_reg_ddst = abi_not_param1; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t aux1_reg_kernel = rbp; + reg64_t reg_dsrc = rsi; + + reg64_t reg_ur_str_w = r9; + reg64_t reg_ch_blocks = rbx; + + reg64_t iter_kh = r11; + reg64_t iter_kw = r12; + reg64_t reg_kh = r13; + reg64_t reg_kw = r14; + + inline void loop_body(int ur_ch_blocks); + inline void load_ddst(int ur_ch_blocks, int ur_str_w); + inline void apply_filter(int ur_ch_blocks, int ur_str_w); + inline void store_dsrc(int ur_ch_blocks, int ur_str_w); + + void generate(); +}; + +template +struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32) + + jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) { + this->generate(); + jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d, int nthreads); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + static void balance(jit_conv_conf_t &jcp, int nthreads); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_dw_conv_call_s *); + +private: + using Vmm = typename utils::conditional3::type; + using reg64_t = const Xbyak::Reg64; + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int reg_repeats = (isa == sse42) ? 2 : 1; + + const Xbyak::AddressFrame &vmmword + = (isa == sse42) ? xword : (isa == avx2) ? yword : zword; + + /* XXX: offset between input and accummulators is 3, therefore, assume 'kw' + * is no larger than 3*/ + inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); } + inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); } + inline Vmm get_aux_reg() { return Vmm(0); } + + reg64_t reg_tmp_input = r9; + reg64_t reg_tmp_output = r10; + reg64_t reg_tmp_filter = r13; + reg64_t reg_kh_offset = rax; + + /* parameter passed by driver into kernel */ + Xbyak::Reg8 reg_exec_flags = bl; + + reg64_t reg_oh_worksize = r14; + reg64_t reg_oh = rax; + + reg64_t iter_ow_blk = r11; + + reg64_t reg_kh = rsi; + reg64_t reg_kh_count = rdx; + + /* Base addresses for convolution parameters. */ + reg64_t reg_input_baddr = r15; + reg64_t reg_output_baddr = r12; + reg64_t reg_filter_baddr = abi_not_param1; + reg64_t reg_bias_baddr = r13; + + /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs + */ + inline void compute_ow_step_unroll( + int unroll_w, int l_pad, int pad_offset, int ow_block); + + /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */ + inline void compute_h_step( + int unroll_w, int l_pad, int pad_offset, int ow_block); + inline void compute_h_loop( + int unroll_w, int l_pad, int pad_offset, int ow_block); + + /* Write 'width' micro-kernel JITs; depending on the padding and convolution + * size, write a micro-kernel for the left ow-block, middle ow-block(s), and + * right ow-block.*/ + inline void compute_ow_block_unroll(); + + inline void compute_zero_filter(); + inline void load_filter(); + inline void zero_filter(); + inline void load_bias(); + inline void zero_bias(); + inline void compute_bias_step_unroll(const int unroll_w); + inline void compute_bias_loop(const int block_size); + inline void store_filter(); + inline void store_bias(); + + void generate(); +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp new file mode 100644 index 0000000000..58449601a3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp @@ -0,0 +1,427 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "jit_uni_dw_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +template +void _jit_uni_dw_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + if (pd()->wants_padded_bias()) { + auto padded_bias = this->scratchpad(ctx).template get( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + int dil_h = jcp.dilate_h + 1; + int dil_w = jcp.dilate_w + 1; + int str_h = jcp.stride_h; + int str_w = jcp.stride_w; + + auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh, + int kh_padding, int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w)); + const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w + + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw; + + const int iw = nstl::max((ow*str_w - jcp.l_pad + + div_up(i_l_overflow, dil_w)*dil_w), 0); + const int kw = div_up(i_l_overflow, dil_w); + + const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w) + - div_up(i_r_overflow, dil_w); + + par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)]; + + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)]; + if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)]; + + par_conv.kh_padding = (size_t)nstl::max(0, kh_padding); + par_conv.kw_padding = (size_t)nstl::max(0, kw_padding); + + par_conv.ur_w = (size_t)ur_w_step; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(jcp.mb, chb_work, jcp.oh, + [&](int n, int chb, int oh) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h)); + const int i_b_overflow = nstl::max(jcp.ih, + (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih; + + const int ih = nstl::max((int)(oh*str_h - jcp.t_pad + + div_up(i_t_overflow, dil_h)*dil_h), 0); + const int kh = div_up(i_t_overflow, dil_h); + const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h) + - div_up(i_b_overflow, dil_h); + + // left border + int ow = 0; + int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow); + int ur_w_step = 1; + for (; ow < l_border; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + + // main loop + ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1) + / jcp.stride_w - ow + 1; + if (ur_w_step > 0) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + + ow += ur_w_step; + } + + // right border + ur_w_step = 1; + for (; ow < jcp.ow; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +template struct _jit_uni_dw_convolution_fwd_t; +template struct _jit_uni_dw_convolution_fwd_t; +template struct _jit_uni_dw_convolution_fwd_t; + +template +void _jit_uni_dw_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih, + int i_t_overflow, int i_b_overflow, int stride_off_h, + int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad)); + const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw) + - jcp.r_pad)); + + int ow = iw + jcp.l_pad - i_r_overflow; + int stride_off_w = ow % jcp.stride_w; + ow /= jcp.stride_w; + + par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)]; + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow + + stride_off_h, i_r_overflow + stride_off_w)]; + + par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow + - stride_off_h); + par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow + - stride_off_w); + + par_conv.ur_str_w = ur_str_w; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(jcp.mb, chb_work, jcp.ih, + [&](int n, int chb, int ih) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih + - jcp.t_pad)); + const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1 + - (jcp.ih - 1 - ih) - jcp.b_pad)); + + int oh = ih + jcp.t_pad - i_b_overflow; + int stride_off_h = oh % jcp.stride_h; + oh /= jcp.stride_h; + + for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) { + // left border + int iw = i_str_w; + int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw); + int ur_str_w = 1; + for (; iw < l_border; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + + // main loop + ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw) + / jcp.stride_w, jcp.iw); + if (ur_str_w > 0) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + + iw += ur_str_w * jcp.stride_w; + } + + // right border + ur_str_w = 1; + for (; iw < jcp.iw; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + } + }); +} + +template struct _jit_uni_dw_convolution_bwd_data_t; +template struct _jit_uni_dw_convolution_bwd_data_t; +template struct _jit_uni_dw_convolution_bwd_data_t; + +template +_jit_uni_dw_convolution_bwd_weights_t:: +_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), acc_ker_(nullptr) +{ + kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32(pd()->jcp_); + if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction()) + acc_ker_ = new cpu_accumulator_1d_t(); +} + +template +void _jit_uni_dw_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto diff_wei_reduction_buf = + scratchpad(ctx).template get(key_conv_wei_reduction); + auto diff_bia_reduction_buf = + scratchpad(ctx).template get(key_conv_bia_reduction); + + const auto &jcp = kernel_->jcp; + + /* Used when executing a parallel reduction */ + simple_barrier::ctx_t reduction_bctx; + simple_barrier::ctx_init(&reduction_bctx); + + const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; + const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0; + + const int ch_block = jcp.ch_block; + + auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params, + const int batch, const int group, const int oh_start, + const int work_size, const unsigned char exec_flag, + const size_t kh_padding, const size_t filter_off) { + const int tpad_underflow_off = jcp.t_pad - filter_off; + + conv_params->exec_flags = exec_flag; + conv_params->kh_count = jcp.kh - kh_padding; + + const int oh_s = oh_start; + const int oh_e = oh_start + work_size; + const int ih_s = oh_s * jcp.stride_h; + + conv_params->filter_pad_off + = filter_off * jcp.kw * ch_block * sizeof(float); + conv_params->oh_index = oh_s; + conv_params->oh_count = oh_e; + + size_t diff_dst_off + = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh + + oh_start) + * jcp.ow; + + size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih + + ih_s - tpad_underflow_off) * jcp.iw; + + conv_params->output = &diff_dst[diff_dst_off * ch_block]; + conv_params->input = &src[src_off * ch_block]; + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + + auto conv_params = jit_dw_conv_call_s(); + const int h_block_size = 15; + + /* assign iteration space to thread */ + const int ithr_g = ithr % jcp.nthr_g; + const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb; + + /* split dimensions */ + int g_start{ 0 }, g_end{ 0 }; + balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end); + + int mb_start{ 0 }, mb_end{ 0 }; + balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end); + + auto diff_wei = ithr_mb == 0 + ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size; + auto diff_bia = ithr_mb == 0 + ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size; + + for (int g = g_start; g < g_end; ++g) { + unsigned char zero_filter_flag = FLAG_ZERO_FILTER; + unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0; + + size_t diff_wei_off = g * jcp.kh * jcp.kw; + conv_params.filter = &diff_wei[diff_wei_off * ch_block]; + + if (jcp.with_bias) + conv_params.bias = &diff_bia[g * ch_block]; + + for (int mb = mb_start; mb < mb_end; ++mb) { + int oh = 0; + while (oh < jcp.oh) { + const int h_work = nstl::min(h_block_size, jcp.oh - oh); + auto kh_t_padding = nstl::max(0, jcp.t_pad - oh); + auto kh_b_padding + = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ? + jcp.b_pad - (h_work - 1) : + 0; + + set_kernel_params(&conv_params, mb, g, oh, h_work, + zero_filter_flag | zero_bias_flag, + kh_t_padding + kh_b_padding, kh_t_padding); + kernel_->jit_ker(&conv_params); + + zero_bias_flag &= ~FLAG_ZERO_BIAS; + zero_filter_flag &= ~FLAG_ZERO_FILTER; + oh += h_work; + } + } + } + + if (do_parallel_reduction() && jcp.nthr_mb > 1) { + size_t reduct_start{ 0 }, reduct_end{ 0 }; + balance211(wei_size, nthr, ithr, reduct_start, reduct_end); + + const int acc_size = reduct_end - reduct_start; + const size_t reduct_off = reduct_start; + auto *acc_data = diff_weights + reduct_off; + + simple_barrier::barrier(&reduction_bctx, nthr); + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + auto *src_data = diff_wei_reduction_buf + + (thr_mb - 1) * wei_size + reduct_off; + acc_ker_->accumulate(acc_data, src_data, acc_size); + } + } + }); + + if (jcp.nthr_mb <= 1) return; + + /* Apply single-threaded 'mb' reduction */ + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + size_t mb_accum_offset = (thr_mb - 1) * wei_size; + size_t b_accum_offset = (thr_mb - 1) * bias_size; + + for (int g = 0; g < jcp.nb_ch; ++g) { + /* Reduction on Bias */ + if (jcp.with_bias) { + PRAGMA_OMP_SIMD() + for (int g_block = 0; g_block < ch_block; ++g_block) { + size_t bias_offset = g * ch_block + g_block; + diff_bias[bias_offset] += diff_bia_reduction_buf[ + b_accum_offset + bias_offset]; + } + } + + if (do_parallel_reduction()) continue; + + for (int kh = 0; kh < jcp.kh; ++kh) + for (int kw = 0; kw < jcp.kw; ++kw) + { + size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw; + PRAGMA_OMP_SIMD() + for (int g_block = 0; g_block < ch_block; ++g_block) { + const size_t off = wei_offset * ch_block + g_block; + diff_weights[off] += + diff_wei_reduction_buf[mb_accum_offset + off]; + } + } + } + } +} + +template struct _jit_uni_dw_convolution_bwd_weights_t; +template struct _jit_uni_dw_convolution_bwd_weights_t; +template struct _jit_uni_dw_convolution_bwd_weights_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp new file mode 100644 index 0000000000..ca53749ec2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_DW_CONVOLUTION_HPP +#define CPU_JIT_UNI_DW_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_uni_dw_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct _jit_uni_dw_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_uni_dw_conv_fwd_kernel_f32::init_conf( + jcp_, *desc(), src_md(), *weights_md(), *dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_fwd_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_dw_conv_fwd_kernel_f32(pd()->jcp_); } + + ~_jit_uni_dw_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_fwd_kernel_f32 *kernel_; +}; + +using jit_avx512_common_dw_convolution_fwd_t = + _jit_uni_dw_convolution_fwd_t; +using jit_avx2_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t; +using jit_sse42_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t; + +template +struct _jit_uni_dw_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + + if (!ok) return status::unimplemented; + + status_t status = jit_uni_dw_conv_bwd_data_kernel_f32:: + init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_bwd_data_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_dw_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~_jit_uni_dw_convolution_bwd_data_t() { delete kernel_; }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_bwd_data_kernel_f32 *kernel_; +}; + +using jit_avx512_common_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; +using jit_avx2_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; +using jit_sse42_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t; + +template +struct _jit_uni_dw_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const int max_threads = mkldnn_in_parallel() + ? 1 : mkldnn_get_max_threads(); + + status_t status = jit_uni_dw_conv_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *diff_weights_md(), + *diff_dst_md(), max_threads); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_bwd_weights_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd); + ~_jit_uni_dw_convolution_bwd_weights_t() { + delete kernel_; + delete acc_ker_; + }; + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + bool do_parallel_reduction() const { return false; } + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_bwd_weights_kernel_f32 *kernel_; + cpu_accumulator_1d_t *acc_ker_; +}; + +using jit_avx512_common_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; +using jit_avx2_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; +using jit_sse42_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp new file mode 100644 index 0000000000..2af6435871 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp @@ -0,0 +1,1142 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_uni_eltwise.hpp" + +#define GET_OFF(field) offsetof(jit_args, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +template +void jit_uni_eltwise_injector_f32::injector_preamble(size_t start_idx, + size_t end_idx) { + preserved_vecs_count = 0; + vecs_to_preserve = (size_t)aux_vecs_count(alg_); + start_idx_tail = start_idx; + + // For sse42 mask register has to be Xmm(0) + if (isa == sse42 && vecs_to_preserve > 0) { + size_t idx = 0; + assert(idx < start_idx); + preserved_vec_idxs[preserved_vecs_count++] = idx; + } + + for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) { + if (preserved_vecs_count >= vecs_to_preserve) break; + if (start_idx <= idx && idx < end_idx) continue; + + preserved_vec_idxs[preserved_vecs_count++] = idx; + } + + size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count; + for (size_t i = 0; i < preserved_vecs_count_tail; i++) { + preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++; + } + + assert(preserved_vecs_count == vecs_to_preserve); + + if (save_state_) { + h->push(p_table); + + if (preserved_vecs_count) + h->sub(h->rsp, preserved_vecs_count * vlen); + + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], + Vmm(preserved_vec_idxs[i])); + + load_table_addr(); + } + + assign_regs(); +} + +template +void jit_uni_eltwise_injector_f32::injector_preamble_tail(size_t start_idx) +{ + size_t tail_vecs_to_preserve = start_idx_tail - start_idx; + if (tail_vecs_to_preserve == 0) return; + + const int idx_off = vecs_to_preserve - tail_vecs_to_preserve; + + if (save_state_) { + if (idx_off) + h->add(h->rsp, idx_off * vlen); + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), + h->ptr[h->rsp + i * vlen]); + } + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve; + + if (save_state_) { + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], + Vmm(preserved_vec_idxs[idx_off + i])); + + if (idx_off) + h->sub(h->rsp, idx_off * vlen); + } + + assign_regs(); +} + +template +void jit_uni_eltwise_injector_f32::injector_postamble() { + if (!save_state_) return; + + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[i]), + h->ptr[h->rsp + i * vlen]); + + if (preserved_vecs_count) + h->add(h->rsp, preserved_vecs_count * vlen); + + h->pop(p_table); +} + +template +void jit_uni_eltwise_injector_f32::assign_regs() { + vmm_mask = Vmm(preserved_vec_idxs[0]); + vmm_aux0 = Vmm(preserved_vec_idxs[0]); + vmm_aux1 = Vmm(preserved_vec_idxs[1]); + vmm_aux2 = Vmm(preserved_vec_idxs[2]); + vmm_aux3 = Vmm(preserved_vec_idxs[3]); + vmm_aux4 = Vmm(preserved_vec_idxs[4]); +} + +template +void jit_uni_eltwise_injector_f32::exp_compute_vector(const Vmm &vmm_src) { + h->uni_vminps(vmm_src, vmm_src, table_val(10)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(11)); + h->uni_vmovups(vmm_aux0, vmm_src); + //calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + if (isa == avx512_common) { + h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src); + h->vcvtdq2ps(vmm_aux1, vmm_aux1); + + h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us); + h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); + + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3); + } else { + h->uni_vroundps(vmm_aux1, vmm_src, _op_floor); + } + + //keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx + + //x = x - fx * ln2 + h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3)); + + // compute 2^n + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx + + // y = p5 + h->uni_vmovups(vmm_src, table_val(9)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q) + // y = y * 2^n + h->uni_vmulps(vmm_src, vmm_src, vmm_aux1); +} + +template +void jit_uni_eltwise_injector_f32::relu_compute_vector(const Vmm &vmm_src) +{ + const int alpha_off = 0, zero_off = 1; + + h->uni_vmovups(vmm_aux1, vmm_src); + if (isa == sse42) { + h->movups(vmm_mask, vmm_src); + h->mulps(vmm_src, table_val(alpha_off)); + h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us); + h->blendvps(vmm_src, vmm_aux1); + } else if (isa == avx2) { + h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); + h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off)); + h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } else if (isa == avx512_common) { + h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); + h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } +} + +template +void jit_uni_eltwise_injector_f32::relu_zero_ns_compute_vector( + const Vmm &vmm_src) { + const int zero_off = 1; + h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off)); +} + +template +void jit_uni_eltwise_injector_f32::elu_compute_vector(const Vmm &vmm_src) { + const int alpha_off = 23, zero_off = 24; + + // compute exponent + h->uni_vmovups(vmm_aux2, vmm_src); + exp_compute_vector(vmm_src); + + // alpha * (exp(x) - 1) + h->uni_vsubps(vmm_src, vmm_src, table_val(0)); + h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off)); + + // combine with mask + if (isa == sse42) { + h->pxor(vmm_mask, vmm_mask); + h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os); + h->blendvps(vmm_src, vmm_aux2); + } else if (isa == avx2) { + h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off)); + h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask); + } else if (isa == avx512_common) { + h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2); + } +} + +template +void jit_uni_eltwise_injector_f32::tanh_compute_vector(const Vmm &vmm_src) +{ + // # comes from Taylor expansion error bound + // > linear_sat_point = single(sqrt(3) * 1b-12); + // # comes from the exp formula cancellation + // > exp_bound_point = (single(log(3)/2)); + // # comes from rounding accuracy in float + // > one_sat_point = round(atanh(1 - 1b-25), single, RU); + // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |], + // [linear_sat_point, exp_bound_point], relative, floating); + // > err_bound = D(sup(supnorm(P, tanh(x), + // [linear_sat_point, exp_bound_point], relative, theta))); + // 0x1.fffd6f00b9539p-25 + // > P; + // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 * + // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5 + // + x^0x1p1 * 0x1.09fa1p-6)))) + + // register mapping + // vmm_src contains input + // vmm_aux0 contains mask of currently valid results. + // 1 is need computation, 0 is already computed + // vmm_aux1 contains current output + // vmm_aux2, vmm_aux3 contains auxiliary values + // vmm_aux4 contains the original sign of inputs + + Label end_tanh_label; + + auto test_exit =[&](Xbyak::Address threshold){ + // is not necessary for >AVX, but should not matter on perf + h->uni_vmovups(vmm_aux0, vmm_src); + if (isa == avx512_common){ + h->vcmpps(k_mask, vmm_aux0, threshold, 0x5); + h->kortestw(k_mask, k_mask); + } else { + h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold); + h->uni_vtestps(vmm_aux0, vmm_aux0); + } + h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR); + }; + + auto blend_results=[&](Vmm vmm_partial_res){ + if (isa == avx512_common) + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res); + else + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0); + }; + + // because tanh(x) = -tanh(-x), we extract sign to make x postive + // and reapply sign at the end + // mov is not necessary for >AVX, but should not matter for performance + h->uni_vmovups(vmm_aux4, vmm_src); + h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12)); + h->uni_vandps(vmm_src, vmm_src, table_val(17)); + + // if x < linear_sat_point for all inputs, we just return the input + h->uni_vmovups(vmm_aux1, vmm_src); + test_exit(table_val(13)); + + // if one of the mask is one, we have to compute an better approx + h->uni_vmovups(vmm_aux2, vmm_src); + h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2); + h->uni_vmovups(vmm_aux3, table_val(22)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18)); + h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src); + + // we blend only the result that need update + blend_results(vmm_aux3); + + // if x < exp_bound_point, we go to return point + test_exit(table_val(14)); + + // if not we use a better approx 1 - 2 / (1 + exp(2x)) + // compute 2x + h->uni_vmovups(vmm_aux3, vmm_src); + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3); + + // Compute exp(2x) + // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them + // vmm_src is not more read afterwards, so we do not have to save it + auto stack_size = 3 * vlen + (isa == avx512_common) * 4; + h->sub(h->rsp, stack_size); + h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0); + h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1); + h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src); + if (isa == avx512_common) + h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask); + + exp_compute_vector(vmm_aux3); + + h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]); + h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]); + h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]); + if (isa == avx512_common) + h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]); + h->add(h->rsp, stack_size); + + // 1 + exp(2x) + h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0)); + + // 1 - 2 / (1 + exp(2x)) + h->uni_vmovups(vmm_aux2, table_val(16)); + h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3); + h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0)); + + // we blend only the result that need update + blend_results(vmm_aux2); + + // finally, we saturate to 1 if needed + // TODO: maybe move that up if most inputs saturate in practice + if (isa == avx512_common) + h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5); + else { + h->uni_vmovups(vmm_aux0, vmm_src); + h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15)); + } + h->uni_vmovups(vmm_aux2, table_val(0)); + blend_results(vmm_aux2); + + h->L(end_tanh_label); + { + // we apply the sign of x to the result and we are done + h->uni_vmovups(vmm_src, vmm_aux1); + h->uni_vpxor(vmm_src, vmm_src, vmm_aux4); + } +} + +template +void jit_uni_eltwise_injector_f32::square_compute_vector( + const Vmm &vmm_src) { + h->uni_vmulps(vmm_src, vmm_src, vmm_src); +} + +template +void jit_uni_eltwise_injector_f32::abs_compute_vector(const Vmm &vmm_src) { + // compute abs(x) = _mm_and_ps(x, 01111..111)); + h->uni_vandps(vmm_src, vmm_src, table_val(0)); +} + +template +void jit_uni_eltwise_injector_f32::sqrt_compute_vector(const Vmm &vmm_src) +{ + if (isa == avx512_common) { + h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us); + h->uni_vsqrtps(vmm_aux1, vmm_src); + h->uni_vmovups(vmm_src, table_val(0)); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } else { + h->uni_vmovups(vmm_mask, vmm_src); + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0)); + h->uni_vsqrtps(vmm_aux1, vmm_src); + h->uni_vmovups(vmm_src, table_val(0)); + h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } +} + +template +void jit_uni_eltwise_injector_f32::linear_compute_vector( + const Vmm &vmm_src) { + // compute x = alpha * x + beta; + h->uni_vmovups(vmm_aux0, table_val(0)); + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1)); +} + +template +void jit_uni_eltwise_injector_f32::bounded_relu_compute_vector( + const Vmm &vmm_src) { + // compute bounded relu */ + h->uni_vmaxps(vmm_src, vmm_src, table_val(1)); + h->uni_vminps(vmm_src, vmm_src, table_val(0)); +} + +template +void jit_uni_eltwise_injector_f32::soft_relu_compute_vector( + const Vmm &vmm_src) { + // duplicate src + h->uni_vmovups(vmm_aux2, vmm_src); + + h->uni_vminps(vmm_src, vmm_src, table_val(24)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(25)); + h->uni_vmovups(vmm_aux1, vmm_src); + // calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + if (isa == avx512_common) { + h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src); + h->vcvtdq2ps(vmm_aux0, vmm_aux0); + + h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us); + h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); + + h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3); + } else { + h->uni_vroundps(vmm_aux0, vmm_src, _op_floor); + } + + // keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx + // calculation fx * ln2 + h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3)); + // x = x - fx * ln2 + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0); + // y = p5 + h->uni_vmovups(vmm_aux3, table_val(22)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17)); + + // compute 2^(-n) + if (isa == avx512_common) { + h->vmulps(vmm_aux1, vmm_src, table_val(23)); + h->vcvtps2dq(vmm_aux1, vmm_aux1); + } else { + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23)); + } + + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx + // calculate ln(1 + y) + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1); + // x = y; y is free; keep x for further computations + h->uni_vmovups(vmm_src, vmm_aux3); + // frexp() + h->uni_vpsrld(vmm_src, vmm_src, 23); + h->uni_vcvtdq2ps(vmm_src, vmm_src); + // got n. where n is x = 2^n * y. y = 0.5 .. 1 + h->uni_vsubps(vmm_src, vmm_src, table_val(5)); + + h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6)); + // got y. (mantisa) 0.5 < y < 1 + h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7)); + // y = y - 1 + h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0)); + // y = p8 + h->uni_vmovups(vmm_aux1, table_val(16)); + // y = y * x + p7 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15)); + // y = y * x + p6 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14)); + // y = y * x + p5 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9)); + // y = y * x + p0 ; p0 = 0 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8)); + //calculate ln(2) * n + h->uni_vmulps(vmm_src, vmm_src, table_val(3)); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0); + + // get vmm_mask = src > max logf + h->uni_vmovups(vmm_mask, vmm_aux2); + if (isa == avx512_common) { + // y = (x < max log f) ? soft_relu(x) : x + h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us); + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2); + } else { + // y = (x < max log f) ? soft_relu(x) : x + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24)); + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask); + } + + h->uni_vmovups(vmm_src, vmm_aux1); +} + +template +void jit_uni_eltwise_injector_f32::logistic_compute_vector( + const Vmm &vmm_src) { + // we store the original sign and make x negative + // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required + // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it. + h->uni_vmovups(vmm_aux2, vmm_src); + h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12)); + h->uni_vorps(vmm_src, vmm_src, table_val(12)); + + exp_compute_vector(vmm_src); + // dup exp(x) + h->uni_vmovups(vmm_aux1, vmm_src); + // (exp(x) + 1) + h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0)); + // y = exp(x) / (exp(x) + 1) + h->uni_vdivps(vmm_src, vmm_src, vmm_aux1); + + // Now we have to apply the "symmetry" based on original sign + h->uni_vmovups(vmm_aux3, table_val(0)); + h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); + if (isa == avx512_common) { + h->vptestmd(k_mask, vmm_aux2, vmm_aux2); + h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src); + } else { + h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2 + h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0); + } + h->uni_vmovups(vmm_src, vmm_aux3); +} + +template +void jit_uni_eltwise_injector_f32::relu_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::elu_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + // exp(x) polynom + 0x3f800001, // [5] p0 = 1.0000001f + 0x3efffe85, // [6] p2 = 0.4999887f + 0x3e2aaa3e, // [7] p3 = 0.16666505f + 0x3d2bb1b1, // [8] p4 = 0.041917507f + 0x3c091ec1, // [9] p5 = 0.008369149f + 0x42b0c0a5, //[10] max logf = 88.3762589f + 0xc1766666, //[11] min logf = -14.5f + // tanh(x) constants, + 0x80000000, //[12] mask to extract sign + 0x39ddb3d7, //[13] arg below which tanh(x) = x + 0x3f0c9f54, //[14] arg below which pol approx is valid + 0x41102cb4, //[15] arg after which tanh(x) = 1 + 0xc0000000, //[16] -2.0f + 0x7fffffff, //[17] mask to make positive + // tanh pol approx + 0x3f7fffff, //[18] p0 + 0xbeaaa9cf, //[19] p1 + 0x3e085f1f, //[20] p2 + 0xbd572bda, //[21] p3 + 0x3c84fd08, //[22] p4 + }; + + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]); + } + + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::soft_relu_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + 0x42fc0000, // [5] 126 + 0x807fffff, // [6] and with (to get 0.5 * mantissa) + 0x3f000000, // [7] or with (to get 0.5 * mantissa) + // ln(1 + x) polynomial + 0xb2b4637d, // [8] p0 = 0.0000000244f + 0x3f7fff8e, // [9] p1 = 0.9999976971f + 0xbf001759, //[10] p2 = -0.5002478215f + 0x3ea70608, //[11] p3 = 0.3272714505f + 0xbea3d7bf, //[12] p4 = -0.3153830071f + 0xbe361d04, //[13] p5 = -0.1701777461f + 0xbfa8f1e6, //[14] p6 = -1.3254635147f + 0xbfe1e812, //[15] p7 = -1.7971917960f + 0xbfc4d30e, //[16] p8 = -1.5652673123f + // exp(x) polynomial + 0x3f800001, //[17] p0 = 1.0000001f + 0x3f800000, //[18] p1 = 1.0f + 0x3efffe85, //[19] p2 = 0.4999887f + 0x3e2aaa3e, //[20] p3 = 0.16666505f + 0x3d2bb1b1, //[21] p4 = 0.041917507f + 0x3c091ec1, //[22] p5 = 0.008369149f + 0xbf800000, //[23] is required for sign changing + 0x42b0c0a5, //[24] max logf = 88.3762589f + 0xc1766666 //[25] min logf = -14.5f + }; + + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) { + h->dd(cvals[i]); + } + } +} + +template +void jit_uni_eltwise_injector_f32::abs_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff); +} + +template +void jit_uni_eltwise_injector_f32::sqrt_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +void jit_uni_eltwise_injector_f32::linear_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_)); +} + +template +void jit_uni_eltwise_injector_f32::bounded_relu_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template +int jit_uni_eltwise_injector_f32::aux_vecs_count(alg_kind_t alg_) { + switch (alg_) { + case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2; + case alg_kind::eltwise_elu: return 4; + case alg_kind::eltwise_tanh: return 5; + case alg_kind::eltwise_square: return 0; + case alg_kind::eltwise_abs: return 0; + case alg_kind::eltwise_sqrt: return 2; + case alg_kind::eltwise_linear: return 1; + case alg_kind::eltwise_bounded_relu: return 0; + case alg_kind::eltwise_soft_relu: return 4; + case alg_kind::eltwise_logistic: return 4; + default: assert(!"unsupported eltwise algorithm"); + } + + return 0; +} + +template +void jit_uni_eltwise_injector_f32::compute_body(size_t start_idx, + size_t end_idx) { + using namespace alg_kind; + for (size_t idx = start_idx; idx < end_idx; idx++) { + switch (alg_) { + case eltwise_relu: + if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx)); + else relu_compute_vector(Vmm(idx)); + break; + case eltwise_elu: elu_compute_vector(Vmm(idx)); break; + case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break; + case eltwise_square: square_compute_vector(Vmm(idx)); break; + case eltwise_abs: abs_compute_vector(Vmm(idx)); break; + case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break; + case eltwise_linear: linear_compute_vector(Vmm(idx)); break; + case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break; + case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break; + case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break; + default: assert(!"unsupported eltwise algorithm"); + } + } +} + +template +void jit_uni_eltwise_injector_f32::compute_vector_range(size_t start_idx, + size_t end_idx) { + assert(start_idx < end_idx && end_idx <= vecs_count); + + injector_preamble(start_idx, end_idx); + compute_body(start_idx_tail, end_idx); + injector_preamble_tail(start_idx); + compute_body(start_idx, start_idx_tail); + injector_postamble(); +} + +template +void jit_uni_eltwise_injector_f32::prepare_table(bool gen_table) { + using namespace alg_kind; + + h->align(64); + h->L(l_table); + + if (gen_table) { + switch (alg_) { + case eltwise_relu: relu_prepare_table(); break; + case eltwise_elu: + case eltwise_tanh: + case eltwise_logistic: + elu_prepare_table(); break; + case eltwise_soft_relu: soft_relu_prepare_table(); break; + case eltwise_abs: abs_prepare_table(); break; + case eltwise_sqrt: sqrt_prepare_table(); break; + case eltwise_linear: linear_prepare_table(); break; + case eltwise_bounded_relu: bounded_relu_prepare_table(); break; + case eltwise_square: break; + default: assert(!"unsupported eltwise algorithm"); + } + } +} + +template struct jit_uni_eltwise_injector_f32; +template struct jit_uni_eltwise_injector_f32; +template struct jit_uni_eltwise_injector_f32; + + +struct jit_args { + const float *from; + const float *for_comparison; + const float *to; + size_t work_amount; +}; + +struct jit_uni_eltwise_kernel_f32 : public c_compatible { + const eltwise_desc_t &desc_; + + void (*ker_)(const jit_args *); + void operator()(const jit_args *args) { assert(ker_); ker_(args); } + + jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc) + : desc_(desc), ker_(nullptr) {} + virtual ~jit_uni_eltwise_kernel_f32() {} + +protected: + bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; } +}; + +/* jit kernels */ +namespace { + +template +struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32, + public jit_generator +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32) + + void compute_step(bool vectorize, const int uf, const int shift) { + for (int i = 0; i < uf; i++) { + if (vectorize) { + uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]); + if (is_bwd()) + uni_vmovups(Vmm(uf + i + 1), + ptr[reg_for_comparison + i * shift]); + } else { + movss(Xmm(i + 1), ptr[reg_from + i * shift]); + if (is_bwd()) + movss(Xmm(uf + i + 1), + ptr[reg_for_comparison + i * shift]); + } + } + + if (isa == sse42) { + for (int i = 0; i < uf; i++) { + movups(Vmm(2 * uf + i + 1), Vmm(i + 1)); + mulps(Vmm(2 * uf + i + 1), vmm_ns); + + Vmm mask = Vmm(0); + if (is_bwd()) { + movups(mask, Vmm(uf + i + 1)); + cmpps(mask, vmm_zero, _cmp_nle_us); + } else { + movups(mask, Vmm(i + 1)); + cmpps(mask, vmm_zero, _cmp_nle_us); + } + blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1)); + } + } else { + for (int i = 0; i < uf; i++) { + vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns); + if (isa == avx2) { + if (is_bwd()) + vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero); + else + vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero); + + vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1), + Vmm(i + 1), vmm_mask); + + } else { + if (is_bwd()) + vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us); + else + vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us); + vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1), + Vmm(i + 1)); + } + } + } + + for (int i = 0; i < uf; i++) { + if (vectorize) { + uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1)); + } else { + movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1)); + } + } + } + + jit_uni_relu_kernel_f32(const eltwise_desc_t &desc) + : jit_uni_eltwise_kernel_f32(desc), jit_generator() { + assert(desc.alg_kind == alg_kind::eltwise_relu); + assert(isa == sse42 || isa == avx2 || isa == avx512_common); + + Reg64 param = abi_param1; + + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int loop_dec[] = {simd_w, 1}; + const int uf[] = {1, 1}; + const int shift[] = {cpu_isa_traits::vlen, sizeof(float)}; + const bool loop_vectorize[] = {true, false}; + + this->preamble(); + + mov(reg_from, ptr[param + GET_OFF(from)]); + if (is_bwd()) + mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]); + mov(reg_to, ptr[param + GET_OFF(to)]); + mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); + + mov(imm_addr64, float2int(desc.alpha)); + movq(xmm_ns, imm_addr64); + uni_vbroadcastss(vmm_ns, xmm_ns); + + uni_vpxor(vmm_zero, vmm_zero, vmm_zero); + + Label loop_label[3]; + + for (int id = 0; id < 2; id++) { + L(loop_label[id]); + cmp(reg_work_amount, uf[id] * loop_dec[id] - 1); + jle(loop_label[id + 1], T_NEAR); + + compute_step(loop_vectorize[id], uf[id], shift[id]); + + add(reg_from, uf[id] * shift[id]); + add(reg_to, uf[id] * shift[id]); + if (is_bwd()) + add(reg_for_comparison, uf[id] * shift[id]); + + sub(reg_work_amount, uf[id] * loop_dec[id]); + jmp(loop_label[id]); + } + + L(loop_label[2]); + this->postamble(); + + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using Vmm = typename utils::conditional3::type; + + Reg64 reg_from = rax; + Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from; + Reg64 reg_to = r8; + Reg64 reg_work_amount = rsi; + Reg64 imm_addr64 = rbx; + + Xmm xmm_ns = Xmm(14); + + Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14); + Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15); + + Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12); + Opmask k_mask = Opmask(1); +}; + +template +struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32, + public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32) + + jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc) + : jit_uni_eltwise_kernel_f32(desc), jit_generator() { + + eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, + desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1)); + + using namespace alg_kind; + + assert(is_bwd() == false); + assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); + + preamble(); + + Reg64 param = abi_param1; + mov(reg_from, ptr[param + GET_OFF(from)]); + mov(reg_to, ptr[param + GET_OFF(to)]); + mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); + eltwise_injector_->load_table_addr(); + + Label reminder_loop_start, reminder_loop_end; + Label vectorized_loop_start, vectorized_loop_end; + + cmp(reg_work_amount, simd_w); + jl(reminder_loop_start, T_NEAR); + + L(vectorized_loop_start); + + uni_vmovups(vmm_src, ptr[reg_from]); + eltwise_injector_->compute_vector(vmm_src.getIdx()); + uni_vmovups(ptr[reg_to], vmm_src); + + add(reg_from, vlen); + add(reg_to, vlen); + + sub(reg_work_amount, simd_w); + cmp(reg_work_amount, simd_w); + jge(vectorized_loop_start, T_NEAR); + + L(vectorized_loop_end); + + L(reminder_loop_start); + + cmp(reg_work_amount, 0); + jle(reminder_loop_end, T_NEAR); + + movss(xmm_src, ptr[reg_from]); + eltwise_injector_->compute_vector(xmm_src.getIdx()); + movss(ptr[reg_to], xmm_src); + + add(reg_from, sizeof(float)); + add(reg_to, sizeof(float)); + + dec(reg_work_amount); + jmp(reminder_loop_start, T_NEAR); + + L(reminder_loop_end); + + postamble(); + + eltwise_injector_->prepare_table(); + + ker_ = (decltype(ker_))this->getCode(); + } + + ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; } + +private: + using Vmm = typename utils::conditional3::type; + + const int simd_w = cpu_isa_traits::vlen / sizeof(float); + const int vlen = cpu_isa_traits::vlen; + + Reg64 reg_from = rax; + Reg64 reg_to = r8; + Reg64 reg_work_amount = rsi; + Reg64 imm_addr64 = rbx; + + Xmm xmm_src = Xmm(1); + Vmm vmm_src = Vmm(1); + + jit_uni_eltwise_injector_f32 *eltwise_injector_; +}; + +} /* namespace */ + +template +status_t jit_uni_eltwise_fwd_t::pd_t::init() { + using namespace alg_kind; + + bool ok = true + && mayiuse(isa) + && is_fwd() + && utils::everyone_is(data_type::f32, desc()->data_desc.data_type) + && !has_zero_dim_memory() + && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh, + eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, + eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, + eltwise_logistic) + && memory_desc_wrapper(src_md()).is_dense(true) + && IMPLICATION(!memory_desc_wrapper(src_md()).is_dense(false), + math::eltwise_fwd_preserves_zero(desc()->alg_kind, true)) + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; +} + +template +jit_uni_eltwise_fwd_t::jit_uni_eltwise_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) { + const auto &desc = *pd()->desc(); + switch (desc.alg_kind) { + case alg_kind::eltwise_relu: + kernel_ = new jit_uni_relu_kernel_f32(desc); break; + default: + kernel_ = new jit_uni_kernel_fwd_f32(desc); + } +} + +template +jit_uni_eltwise_fwd_t::~jit_uni_eltwise_fwd_t() +{ delete kernel_; } + +template +void jit_uni_eltwise_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const size_t nelems = data_d.nelems(true); + + src += data_d.offset0(); + dst += data_d.offset0(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + + const int cache_line = 16; + + balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); + start = nstl::min(nelems, start * cache_line); + end = nstl::min(nelems, end * cache_line); + + auto arg = jit_args(); + arg.from = &src[start]; + arg.for_comparison = &src[start]; + arg.to = &dst[start]; + arg.work_amount = end - start; + if (arg.work_amount) + (*kernel_)(&arg); + }); +} + +template +status_t jit_uni_eltwise_bwd_t::pd_t::init() { + bool ok = true + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu) + && src_md()->data_type == data_type::f32 + && !has_zero_dim_memory() + && mayiuse(isa) + && memory_desc_wrapper(src_md()).is_dense() + && memory_desc_wrapper(diff_dst_md()) == memory_desc_wrapper(src_md()) + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; +} + +template +jit_uni_eltwise_bwd_t::jit_uni_eltwise_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) { + const auto &desc = *pd()->desc(); + switch (desc.alg_kind) { + case alg_kind::eltwise_relu: + kernel_ = new jit_uni_relu_kernel_f32(desc); break; + default: assert(!"unknown eltwise alg_kind"); + } +} + +template +jit_uni_eltwise_bwd_t::~jit_uni_eltwise_bwd_t() +{ delete kernel_; } + +template +void jit_uni_eltwise_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const size_t nelems = data_d.nelems(); + + src += data_d.offset0(); + diff_dst += diff_data_d.offset0(); + diff_src += diff_data_d.offset0(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + + const int cache_line = 16; + + balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); + start = nstl::min(nelems, start * cache_line); + end = nstl::min(nelems, end * cache_line); + + auto arg = jit_args(); + arg.from = &diff_dst[start]; + arg.to = &diff_src[start]; + arg.for_comparison = &src[start]; + arg.work_amount = end - start; + if (arg.work_amount) + (*kernel_)(&arg); + }); +} + +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; +template struct jit_uni_eltwise_fwd_t; +template struct jit_uni_eltwise_bwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp new file mode 100644 index 0000000000..45436b9f46 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_ELTWISE_HPP +#define CPU_JIT_UNI_ELTWISE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_eltwise_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_eltwise_injector_f32 { + using Vmm = typename utils::conditional3::type; + + jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, + float alpha, float beta, bool save_state = true, + Xbyak::Reg64 p_table = Xbyak::util::rax, + Xbyak::Opmask k_mask = Xbyak::Opmask(1)) + : alg_(alg), alpha_(alpha), beta_(beta), h(host) + , save_state_(save_state), p_table(p_table), k_mask(k_mask) + { + using namespace alg_kind; + assert(utils::one_of(isa, sse42, avx2, avx512_common)); + assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); + } + + // note that eltwise.scale is ignored + jit_uni_eltwise_injector_f32(jit_generator *host, + const post_ops_t::entry_t::eltwise_t &eltwise, + bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax, + Xbyak::Opmask k_mask = Xbyak::Opmask(1)) + : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, + eltwise.beta, save_state, p_table, k_mask) {} + + void compute_vector_range(size_t start_idx, size_t end_idx); + void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); } + void prepare_table(bool gen_table=true); + void load_table_addr() { h->mov(p_table, l_table); } + + const alg_kind_t alg_; + const float alpha_; + const float beta_; + + jit_generator * const h; + + const bool save_state_; + const Xbyak::Reg64 p_table; + const Xbyak::Opmask k_mask; + Xbyak::Label l_table; + +private: + // if only the injector was inherited from jit_generator... + enum { + _cmp_le_os = jit_generator::_cmp_le_os, + _cmp_nle_us = jit_generator::_cmp_nle_us, + _op_floor = jit_generator::_op_floor, + }; + + size_t vlen = cpu_isa_traits::vlen; + + const static size_t preserved_vecs_max = 5; + + size_t vecs_to_preserve = 0; + size_t vecs_count = isa == avx512_common ? 32 : 16; + size_t preserved_vecs_count = 0; + size_t preserved_vec_idxs[preserved_vecs_max] = {0}; + size_t start_idx_tail = 0; + + Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4; + + Xbyak::Address table_val(int index) + { return h->ptr[p_table + index * vlen]; } + + int aux_vecs_count(alg_kind_t alg); + + void compute_body(size_t start_idx, size_t end_idx); + void injector_preamble(size_t start_idx, size_t end_idx); + void injector_preamble_tail(size_t start_idx); + void injector_postamble(); + void assign_regs(); + + void exp_compute_vector(const Vmm &vmm_src); + void relu_compute_vector(const Vmm &vmm_src); + void relu_zero_ns_compute_vector(const Vmm &vmm_src); + void elu_compute_vector(const Vmm &vmm_src); + void tanh_compute_vector(const Vmm &vmm_src); + void square_compute_vector(const Vmm &vmm_src); + void abs_compute_vector(const Vmm &vmm_src); + void sqrt_compute_vector(const Vmm &vmm_src); + void linear_compute_vector(const Vmm &vmm_src); + void bounded_relu_compute_vector(const Vmm &vmm_src); + void soft_relu_compute_vector(const Vmm &vmm_src); + void logistic_compute_vector(const Vmm &vmm_src); + + void relu_prepare_table(); + void elu_prepare_table(); + void soft_relu_prepare_table(); + void abs_prepare_table(); + void sqrt_prepare_table(); + void linear_prepare_table(); + void bounded_relu_prepare_table(); +}; + +struct jit_uni_eltwise_kernel_f32; + +template +struct jit_uni_eltwise_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_eltwise_fwd_pd_t { + using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_eltwise_fwd_t); + + status_t init(); + }; + + jit_uni_eltwise_fwd_t(const pd_t *apd); + ~jit_uni_eltwise_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_eltwise_kernel_f32 *kernel_; +}; + +template +struct jit_uni_eltwise_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_eltwise_bwd_pd_t { + using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_eltwise_bwd_t); + + status_t init(); + }; + + jit_uni_eltwise_bwd_t(const pd_t *apd); + ~jit_uni_eltwise_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_eltwise_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp new file mode 100644 index 0000000000..a3ca6273a0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp @@ -0,0 +1,949 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_uni_i8i8_pooling.hpp" + +#include + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::types; +using namespace alg_kind; + +template +struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t) + + struct call_params_t { + const char *src_i8; + const char *dst_i8; + size_t kw_range; + size_t kh_range; + float idivider; + }; + + using Vmm = typename cpu_isa_traits::Vmm; + Xmm xreg(int idx) const { return Xmm(idx); } + Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); } + Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); } + + // In case of avx2 with data type i8 we need to use + // maskmovdqu instruction which has its destination hardcoded in rdi. + // Windows ABI: abi_param1 is rcx - nothing to do else + // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1 + Reg64 reg_param = rcx; // Our "unified abi_param1" + Reg64 reg_ptr_src_i8 = r8; + Reg64 reg_ptr_dst_i8 = r9; + Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi + + Reg64 ki = r10; + Reg64 kj = r11; + Reg64 reg_kw = r12; + Reg64 reg_kh = r13; + Reg64 c_iter = r14; + + Reg64 aux_reg_src_h = rax; + Reg64 aux_reg_src_w = rbx; + + Reg64 reg_tmp = rdx; + + Reg64 reg_mask = r15; + + Opmask k_cmp_mask = Opmask(7); + + Opmask mask(int idx) { + return Opmask(6 - idx); + } + + // ref to any of XYZ-regs via xreg/yreg/vreg functions + Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp + Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type + Vmm vreg_zeros = vreg(1); + + // only in case of == avx2 + Vmm vreg_mask = vreg(2); // full byte-mask + Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask) + Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately) + Xmm xreg_mask_q = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations + Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails + + enum:int {vidx_base = isa == avx2 ? 4 : 2}; + Vmm base_vr(int idx) const { return vreg(vidx_base + idx); } + + size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); } + size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); } + + /* max pooling */ + Vmm vreg_src(int idx) const { return base_vr(idx); } // [0 .. ur_c-1] + Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1] + + /* avg pooling */ + // s32 used for processing of s8/u8 data + // thus we need to take into account ratio of sizes s32/i8 = 4 + static constexpr data_type_t avg_proc_dt = data_type::s32; + enum:int { + s32_to_i8_ratio = sizeof(typename prec_traits::type) + / sizeof(typename prec_traits::type), + max_num_ll = s32_to_i8_ratio + }; + Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..4 [0..3] + Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..4 [4..7] + Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..4 [8..11] + + void (*ker_)(const call_params_t *); + jit_pool_conf_t jpp; + + void init_tmp_reg(); + void init_mask(); + + void load_vreg_mask_q(int ll) {}; + + void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void load_src(int jj, int ll, int c_tail); + + void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void store_dst(int jj, int ll, int c_tail); + + void compute_avg_step(int ur_c, int c_tail); + void compute_max_op(const int jj); + void compute_max_step(int ur_c, int c_tail); + void compute_step(int ur_c, int c_tail); + + void compute_c_block(); + void generate(); + + static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd); + + jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_) + : jpp(jpp_) { + generate(); + ker_ = reinterpret_cast(const_cast( + getCode())); + } +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_vreg_mask_q(int ll) { + + // extract ll-th part of mask (ll-th QWORD) + vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD + + // Move mask from ll-th pos to 0-th pos + if (ll>0) + vpermq(vreg_mask_q, vreg_mask_q, ll); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + if (jpp.src_dt == s32) { + vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast(msk)); + } else { + vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask); + } + } else + vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + if (jpp.src_dt == s32) + vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); + else + vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); + } else + vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + auto load_i8 = [&](bool is_signed, const Vmm& vr_src) { + + // Need to use mask of tail? + if (masked) { + + // load ll-th part of mask into vreg_mask_q + load_vreg_mask_q(ll); + + // Load by mask from mem into register vr_src + vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q); + + // Conversion s8/u8 -> s32 + if (is_signed) + vpmovsxbd(vr_src, vr_src); + else + vpmovzxbd(vr_src, vr_src); + } else { + + // Load from mem into vr_src with conversion + if (is_signed) + vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); + else + vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); + } + }; + + switch (jpp.src_dt) { + case s32: + if (masked) + vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset], + static_cast(msk)); + else + vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]); + break; + case s8: + load_i8(true, vreg_src_s32(jj, ll)); + break; + case u8: + load_i8(false, vreg_src_s32(jj, ll)); + break; + default: assert(!"unsupported src data type"); + } +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::load_src_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + const Vmm& vr_src = masked ? + vreg_src_s32(jj, ll) | mask(ll) : + vreg_src_s32(jj, ll); + + switch (jpp.src_dt) { + case s32: + vmovups(vr_src, ptr[aux_reg_src_w + offset]); + break; + case s8: + vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); + break; + case u8: + vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); + break; + default: assert(!"unsupported src data type"); + } +}; + +template +void jit_uni_i8i8_pooling_fwd_ker_t::load_src(int jj, int ll, int c_tail) { + using namespace data_type; + + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + + switch (jpp.alg) { + case pooling_max: { + auto offset = jj*c_block*sizeof_src_dt(); + bool masked = jj == ur_c - 1 && c_tail; + load_src_max_op(jj, ll, offset, masked, jpp.tail[0]); + break; + } + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt(); + bool masked = jj == ur_c - 1 && c_tail; + load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + default: assert(!"unsupported algorithm"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + int c_block = jpp.c_block; + + if (masked) { + switch (jpp.src_dt) { + case s32: + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj)); + break; + case s8: + case u8: { + // Store low half by mask (bytes 0...15) + lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); + maskmovdqu(vreg_dst(jj), xreg_mask_lo); + + // Do we need to store high half (bytes 16...31) ? + const uint64_t low_mask = (1ULL << (c_block/2))-1; + if (msk & ~low_mask) { + vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1); + add(reg_ptr_maskmovdqu_dst, c_block / 2); + maskmovdqu(vreg_dst(jj), xreg_mask_hi); + } + } break; + default: assert(!"unsupported src data type"); + } + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + switch (jpp.src_dt) { + case s32: + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); + break; + case s8: + case u8: + vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); + break; + default: assert(!"unsupported src data type"); + } + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk){ + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) { + + // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16} + // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0} + if (is_signed) + vpackssdw(vr_dst, vr_dst, vreg_zeros); + else + vpackusdw(vr_dst, vr_dst, vreg_zeros); + + // Permute qwords to restore original order + // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0} + vpermq(vr_dst, vr_dst, 0x58); + + // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8} + // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx} + if (is_signed) + vpacksswb(vr_dst, vr_dst, vreg_zeros); + else + vpackuswb(vr_dst, vr_dst, vreg_zeros); + + }; + + auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) { + + // Conversion s32 -> s8/u8 + s32_to_i8(is_signed, vr_dst); + + // Need to use mask of tail? + if (is_masked) { + // load ll-th part of mask into vreg_mask_q + load_vreg_mask_q(ll); + } + + // store 8 bytes + lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); + maskmovdqu(vr_dst, xreg_mask_q); + }; + + switch (jpp.dst_dt) { + case s32: + if (masked) { + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll)); + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll)); + break; + case s8: + store_i8(true, masked, vreg_dst_s32(jj, ll)); + break; + case u8: + store_i8(false, masked, vreg_dst_s32(jj, ll)); + break; + default: assert(!"unsuppotred dst data_type"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + const Vmm& vr_dst = masked ? + vreg_dst_s32(jj, ll) | mask(ll) : + vreg_dst_s32(jj, ll); + + switch (jpp.dst_dt) { + case s32: + vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + case s8: + vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + case u8: + vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + default: assert(!"unsupported dst data_type"); + } +} + + +template +void jit_uni_i8i8_pooling_fwd_ker_t::store_dst(int jj, int ll, + int c_tail) { + using namespace data_type; + + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + + switch(jpp.alg) { + case pooling_max: { + auto offset = jj*c_block*sizeof_dst_dt(); + bool masked = jj == ur_c - 1 && c_tail; + store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt(); + bool masked = jj == ur_c - 1 && c_tail; + store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + default: assert(!"unsupported pooling algorithm"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_op(const int jj) +{ + using namespace data_type; + switch (jpp.src_dt) { + case s32: + vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + case s8: + vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + case u8: + vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + default: assert(!"unsupported src data type"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_op(const int jj) +{ + using namespace data_type; + + // Compare + switch (jpp.src_dt) { + case s32: + vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + case s8: + vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + case u8: + vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + default: assert(!"unsupported src data type"); + } + + // move max values into vreg_dst + if (jpp.src_dt == s32) + vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); + else + vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); +} + + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_step(int ur_c, int c_tail) +{ + Label l_kw, l_kh; + + int iw = jpp.iw; + int c = jpp.c; + + for (int jj = 0; jj < ur_c; jj++) + vmovups(vreg_dst(jj), vreg_tmp); + + mov(aux_reg_src_h, reg_ptr_src_i8); + + xor_(kj, kj); + L(l_kh); + { + mov(aux_reg_src_w, aux_reg_src_h); + xor_(ki, ki); + L(l_kw); + { + for (int jj = 0; jj < ur_c; jj++) { + load_src(jj, 0, c_tail); + compute_max_op(jj); + } + add(aux_reg_src_w, c * sizeof_src_dt()); + inc(ki); + cmp(ki, reg_kw); + jl(l_kw, T_NEAR); + } + add(aux_reg_src_h, iw * c * sizeof_src_dt()); + inc(kj); + cmp(kj, reg_kh); + jl(l_kh, T_NEAR); + } + + for (int jj = 0; jj < ur_c; jj++) + store_dst(jj, 0, c_tail); +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step(int ur_c, int c_tail) +{ + using namespace data_type; + + Label l_kw, l_kh; + + int iw = jpp.iw; + int c = jpp.c; + + const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt); + + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll)); + uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll)); + } + } + } + + mov(aux_reg_src_h, reg_ptr_src_i8); + + xor_(kj, kj); + L(l_kh); + { + mov(aux_reg_src_w, aux_reg_src_h); + xor_(ki, ki); + L(l_kw); + { + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + load_src(jj, ll, c_tail); + vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), + vreg_src_s32(jj, ll)); + } + } + } + add(aux_reg_src_w, c * sizeof_src_dt()); + inc(ki); + cmp(ki, reg_kw); + jl(l_kw, T_NEAR); + } + add(aux_reg_src_h, iw * c * sizeof_src_dt()); + inc(kj); + cmp(kj, reg_kh); + jl(l_kh, T_NEAR); + } + + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll)); + vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp); + vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll)); + store_dst(jj, ll, c_tail); + } + } + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_step(int ur_c, int c_tail) { + switch (jpp.alg) { + case pooling_max: + compute_max_step(ur_c, c_tail); break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + compute_avg_step(ur_c, c_tail); break; + default: assert(!"unsupported pooling algorithm"); + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::compute_c_block(){ + Label l_main_loop; + + int nb_c = jpp.nb_c; + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + int ur_c_tail = jpp.ur_c_tail; + int c_steps = nb_c / ur_c; + int c_tail = jpp.c_tail; + + xor_(c_iter, c_iter); + if (c_steps > 0) { + L(l_main_loop); { + compute_step(ur_c, 0); + add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt()); + add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt()); + inc(c_iter); + cmp(c_iter, c_steps); + jl(l_main_loop, T_NEAR); + } + } + + if (ur_c_tail != 0) { + compute_step(ur_c_tail, c_tail); + } +} + +template<> +void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { + using namespace data_type; + using cpu_isa = cpu_isa_traits; + + // AVX2 mask initialization: mask stored in Ymm-regs + auto init = [&](uint64_t bit_mask, bool init_mask_q) { + const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t); + + uint64_t vmask[QW_PER_VREG]; + for (size_t i = 0; i < QW_PER_VREG; i++){ + + uint64_t qw_vmask=0ULL; + const size_t DBITS = 8*sizeof_src_dt(); + const uint64_t VMSK = 1ULL << (DBITS-1); + const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS; + for (size_t j = 0; j < D_PER_QW; j++) { + if (bit_mask & 1) + qw_vmask |= VMSK << DBITS * j; + bit_mask >>= 1; + } + vmask[i] = qw_vmask; + } + + // Put QWORDS with target mask into xmm regs + const int xdst_i[QW_PER_VREG] = { + xreg_mask_lo.getIdx(), + xreg_mask_lo.getIdx(), + xreg_mask_hi.getIdx(), + xreg_mask_hi.getIdx() + }; + const int xsrc_i[QW_PER_VREG] = { + vreg_zeros.getIdx(), // 0-th qword insert in zeros -> {qw0, 0} + xreg_mask_lo.getIdx(), // 1-st and 0-th merge -> {qw0,qw1} + vreg_zeros.getIdx(), + xreg_mask_hi.getIdx() + }; + const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg + + for (size_t i = 0; i < QW_PER_VREG; i++) { + mov(reg_mask, vmask[i]); + vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]); + } + + // Merge Low (xreg_mask_lo alias for vreg_mask.xreg) + // and High (xreg_mask_hi) into full vreg_mask + // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg} + vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1); + + // Keep only low qword of mask in xreg_mask_q + if (init_mask_q) { + mov(reg_mask, vmask[0]); + vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0); + } + }; + + uint64_t tail_mask = (1ULL << jpp.c_tail) - 1; + switch (jpp.alg) { + case pooling_max: + // For "max" we need mask only in case of non-zero tail + if (tail_mask) + init(tail_mask, false); + break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + // For "avg" we need mask: + // - s32 - in case of the non-zero tail + // - s8/u8 - irrespective of the tail + switch (jpp.src_dt) { + case s32: + if (tail_mask) + init(tail_mask, false); + break; + case s8: + case u8: + init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0); + break; + default: assert(!"unsupported src data type"); + } + break; + default: assert(!"unsupported pooling algorithm"); + } +} + +template<> +void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { + + for (int ll = 0; ll < max_num_ll; ll++) { + mov(reg_mask, jpp.tail[ll]); + kmovq(mask(ll), reg_mask); + } +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::init_tmp_reg() { + using namespace data_type; + + switch (jpp.alg) { + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]); + movq(xmm_tmp, reg_tmp); + vpbroadcastd(vreg_tmp, xmm_tmp); + break; + case pooling_max: + switch (jpp.src_dt) { + case s32: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + case s8: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + case u8: + mov(reg_tmp, nstl::numeric_limits::lowest()); + break; + default: assert(!"unsupported src data_type"); + } + + movq(xmm_tmp, reg_tmp); + if (jpp.src_dt == s32) + vpbroadcastd(vreg_tmp, xmm_tmp); + else + vpbroadcastb(vreg_tmp, xmm_tmp); + break; + default: assert(!"unsupported pooling algorithm"); + } + +} + +template +void jit_uni_i8i8_pooling_fwd_ker_t::generate() { + preamble(); + +#if !defined(_WIN32) + // Always use rcx as abi_param1 - + // see the note about maskmovdqu near reg_param. + mov(rcx, rdi); +#endif + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[reg_param + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src_i8, src_i8); + READ_PARAM(reg_ptr_dst_i8, dst_i8); + READ_PARAM(reg_kw, kw_range); + READ_PARAM(reg_kh, kh_range); + +# undef READ_PARAM + + uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros); + + init_mask(); + + init_tmp_reg(); + + compute_c_block(); + + postamble(); +} + +template +status_t jit_uni_i8i8_pooling_fwd_ker_t::init_conf(jit_pool_conf_t &jpp, + const pooling_pd_t *ppd) { + if (!mayiuse(isa)) + return status::unimplemented; + + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d(ppd->src_md()); + const memory_desc_wrapper dst_d(ppd->dst_md()); + + jpp.mb = src_d.dims()[0]; + jpp.c = src_d.dims()[1]; + jpp.ih = src_d.dims()[2]; + jpp.iw = src_d.dims()[3]; + jpp.oh = dst_d.dims()[2]; + jpp.ow = dst_d.dims()[3]; + + jpp.stride_h = pd.strides[0]; + jpp.stride_w = pd.strides[1]; + jpp.kh = pd.kernel[0]; + jpp.kw = pd.kernel[1]; + + jpp.t_pad = pd.padding[0][0]; + jpp.l_pad = pd.padding[0][1]; + + jpp.alg = pd.alg_kind; + + jpp.src_dt = pd.src_desc.data_type; + jpp.dst_dt = pd.dst_desc.data_type; + + // data_type items per one vreg on the + // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32 + // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32 + int simd_w = cpu_isa_traits::vlen / data_type_size(jpp.src_dt); + + jpp.c_block = simd_w; + jpp.c_tail = jpp.c % jpp.c_block; + jpp.nb_c = jpp.c / jpp.c_block; + jpp.ur_c = 1; + jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + + (jpp.c_tail != 0); + + size_t tail_mask = (1ULL << jpp.c_tail) - 1; + + switch (jpp.alg) { + case pooling_max: + jpp.tail[0] = tail_mask; + jpp.tail[1] = 0; + jpp.tail[2] = 0; + jpp.tail[3] = 0; + break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32) + // avx2 : 8, avx512 : 16 + const size_t msk_gran = cpu_isa_traits::vlen / data_type_size(avg_proc_dt); + const size_t msk_msk = (1ULL << msk_gran) - 1; + size_t m = tail_mask; + for (size_t ll = 0; ll < max_num_ll; ll++) { + jpp.tail[ll] = m & msk_msk; + m = m >> msk_gran; + } + break; + } + default: return status::unimplemented; + } + + return status::success; +} + +template +status_t jit_uni_i8i8_pooling_fwd_t::pd_t::jit_conf() { + return jit_uni_i8i8_pooling_fwd_ker_t::init_conf(jpp_, this); +} + +template +jit_uni_i8i8_pooling_fwd_t:: +jit_uni_i8i8_pooling_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), ker_(nullptr) +{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t(pd()->jpp_); } + +template +jit_uni_i8i8_pooling_fwd_t:: +~jit_uni_i8i8_pooling_fwd_t() { delete ker_; } + +template +void jit_uni_i8i8_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src_i8 = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC); + auto dst_i8 = CTX_OUT_MEM(char *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const auto &jpp = pd()->jpp_; + + parallel_nd(jpp.mb, jpp.oh, jpp.ow, + [&](int n, int oh, int ow) { + const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0); + const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0); + + const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h); + const int kh_end = nstl::min(jpp.kh, + jpp.ih + jpp.t_pad - oh * jpp.stride_h); + const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w); + const int kw_end = nstl::min(jpp.kw, + jpp.iw + jpp.l_pad - ow * jpp.stride_w); + + auto p = typename jit_uni_i8i8_pooling_fwd_ker_t::call_params_t(); + p.src_i8 = &src_i8[ + src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()]; + p.dst_i8 = &dst_i8[ + dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()]; + p.kw_range = (size_t)(kw_end - kw_start); + p.kh_range = (size_t)(kh_end - kh_start); + p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ? + p.kh_range*p.kw_range : jpp.kw*jpp.kh); + + ker_->ker_(&p); + }); +} + +// Explicit instantiation only for supported values. +// +template struct jit_uni_i8i8_pooling_fwd_ker_t; +template struct jit_uni_i8i8_pooling_fwd_t; + +template struct jit_uni_i8i8_pooling_fwd_ker_t; +template struct jit_uni_i8i8_pooling_fwd_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp new file mode 100644 index 0000000000..d757679df5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_I8I8_POOLING_HPP +#define CPU_JIT_UNI_I8I8_POOLING_HPP + +#include "c_types_map.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +#include "cpu_isa_traits.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_i8i8_pooling_fwd_ker_t; + +template +struct jit_uni_i8i8_pooling_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_i8i8_pooling_fwd_t); + + status_t init() { + bool ok = true + && mayiuse(isa) + && ndims() == 4 + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::forward_inference + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::one_of(src_md()->data_type, data_type::s32, + data_type::s8, data_type::u8) + && src_md()->data_type == dst_md()->data_type + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && memory_desc_matches_tag(*dst_md(), format_tag::nhwc); + if (!ok) return status::unimplemented; + + return jit_conf(); + } + + jit_pool_conf_t jpp_; + + protected: + status_t jit_conf(); + }; + + jit_uni_i8i8_pooling_fwd_t(const pd_t *apd); + ~jit_uni_i8i8_pooling_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_i8i8_pooling_fwd_ker_t *ker_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp new file mode 100644 index 0000000000..2c5a8e8973 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp @@ -0,0 +1,305 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_uni_lrn_kernel_f32.hpp" +#include "jit_uni_lrn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +template +jit_uni_lrn_fwd_t::jit_uni_lrn_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), ker_(nullptr) + , ker_first_(nullptr), ker_last_(nullptr) +{ + using namespace alg_kind; + + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + float A = pd()->desc()->lrn_alpha / ls; + float K = pd()->desc()->lrn_k; + + auto pk = pd()->desc()->prop_kind; + auto ak = pd()->desc()->alg_kind; + auto dat_tag = pd()->dat_tag_; + + if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, 0), A, K, pk); + ker_first_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, -1), A, K, pk); + ker_last_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_across(H, W, +1), A, K, pk); + } else if (dat_tag == nChw8c && ak == lrn_within_channel) { + /* within channel, local_size (x) local_size */ + A /= ls; /* XXX: why? */ + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw8c_within(H, W, ls), A, K, pk); + } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { + ker_ = new jit_uni_lrn_fwd_kernel_f32( + nchw_across(C, H*W, 0), A, K, pk); + int remind = (H*W) % VECTOR_LENGTH; + if (remind != 0) { + ker_last_ = new jit_uni_lrn_fwd_kernel_f32( + nchw_across(C, H*W, remind), A, K, pk); + } + } else if (true /* XXX: why */) { + ker_ = new jit_uni_lrn_fwd_kernel_f32(nhwc_across(C), A, K, pk); + } +} + +template +jit_uni_lrn_fwd_t::~jit_uni_lrn_fwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +template +void jit_uni_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int HW = pd()->H() * pd()->W(); + const int ls = pd()->desc()->local_size; + + auto ak = pd()->desc()->alg_kind; + auto dat_tag = pd()->dat_tag_; + + if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; + if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else if (dat_tag == nChw8c && ak == lrn_within_channel) { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; + (*ker_)(&args); + }); + } + else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { + parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH, + [&](int n, int hw8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH]; + if ((hw8 + 1)*VECTOR_LENGTH > HW) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else { // nhwc + parallel_nd(N, HW, [&](int n, int hw) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + hw * C]; + args.dst = &dst[n*HW*C + hw * C]; + args.scratch = &ws[n*HW*C + hw * C]; + (*ker_)(&args); + }); + } +} + +template +status_t jit_uni_lrn_fwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(isa) + && is_fwd() + && everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % VECTOR_LENGTH == 0 + && data_d.dims()[1] >= 2 * VECTOR_LENGTH + && desc()->lrn_beta == 0.75 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + if (desc_.prop_kind == forward_training) ws_md_ = *src_md(); + + dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c, nchw, nhwc); + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && one_of(dat_tag_, nChw8c, nchw, nhwc); + + const int jit_max_local_size = 5; // bigger size triggers too big code size + bool args_ok_within = true + && desc()->alg_kind == lrn_within_channel + && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE + ? jit_max_local_size : MAX_LOCAL_SIZE) + && data_d.dims()[2] >= desc()->local_size + && data_d.dims()[3] >= desc()->local_size + && one_of(dat_tag_, nChw8c); + + return args_ok_across || args_ok_within ? success : unimplemented; +} + +template +jit_uni_lrn_bwd_t::jit_uni_lrn_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr) +{ + using namespace alg_kind; + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + float A = pd()->desc()->lrn_alpha / ls; + float B = pd()->desc()->lrn_beta; + + int use_h_parallelizm = 0;// XXX + if (C / VECTOR_LENGTH == 1) { + ker_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, 3), A, B, use_h_parallelizm); + } + else { + ker_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, 0), A, B, use_h_parallelizm); + ker_first_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, -1), A, B, use_h_parallelizm); + ker_last_ = new jit_uni_lrn_bwd_kernel_f32( + nchw8c_across(H, W, +1), A, B, use_h_parallelizm); + } +} + +template +jit_uni_lrn_bwd_t::~jit_uni_lrn_bwd_t() +{ + delete ker_; delete ker_first_; delete ker_last_; +} + +template +void jit_uni_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + int use_h_parallelizm = 0; // XXX + if (use_h_parallelizm) { + parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) { + auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH + + h*W*VECTOR_LENGTH; + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.scratch = &ws[offset]; + args.diff_src = &diff_src[offset]; + if (C / VECTOR_LENGTH == 1) + (*ker_)(&args); + else if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH; + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.scratch = &ws[offset]; + args.diff_src = &diff_src[offset]; + if (C / VECTOR_LENGTH == 1) + (*ker_)(&args); + else if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } +} + +template +status_t jit_uni_lrn_bwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(isa) + && !is_fwd() + && utils::everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % VECTOR_LENGTH == 0 + && desc()->lrn_beta == 0.75 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + ws_md_ = *src_md(); + if (!compare_ws(hint_fwd_pd_)) return unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c); + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && utils::one_of(dat_tag_, nChw8c); + + return args_ok_across ? success : unimplemented; +} + +template struct jit_uni_lrn_fwd_t; +template struct jit_uni_lrn_fwd_t; +template struct jit_uni_lrn_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp new file mode 100644 index 0000000000..333cd3396d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_LRN_HPP +#define CPU_JIT_UNI_LRN_HPP + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_isa_traits.hpp" +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_bwd_kernel_f32; + +template +struct jit_uni_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_lrn_fwd_t); + + status_t init(); + + format_tag_t dat_tag_; + }; + + jit_uni_lrn_fwd_t(const pd_t *apd); + ~jit_uni_lrn_fwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_lrn_fwd_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +template +struct jit_uni_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_lrn_bwd_t); + + status_t init(); + + format_tag_t dat_tag_; + }; + + jit_uni_lrn_bwd_t(const pd_t *apd); + ~jit_uni_lrn_bwd_t(); + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_lrn_bwd_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp new file mode 100644 index 0000000000..89af47272c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp @@ -0,0 +1,1487 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_uni_lrn_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +////////////////////////////////////////////////////////////////////////////// +// forward kernel +template +void jit_uni_lrn_fwd_kernel_f32::within_body( + int hoff, int Hoff, int woff, int Woff, int stride, + Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, + prop_kind_t pk) +{ + vxorps(ysum, ysum, ysum); + for (int i = hoff; i <= Hoff; ++i) + { + for (int j = woff; j <= Woff; ++j) + { + if (i == 0 && j == 0) + { + vmovups(ydst, ptr[src]); + vfmadd231ps(ysum, ydst, ydst); + } + else + { + vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]); + vfmadd231ps(ysum, ytmp, ytmp); + } + } + } + vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk + vmovaps(ytmp, ysum); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ytmp); + vmulps(ysum2, ysum, ysum); + vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3; + vsqrtps(ysum, ysum); + vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75 + vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum + vmovups(ptr[dst], ydst); + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); +} + +template +void jit_uni_lrn_fwd_kernel_f32::within_body_sse42( + int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk) +{ + Xbyak::Xmm xtmp_lo = xmm12; + Xbyak::Xmm xtmp_hi = xmm13; + Xbyak::Xmm xsum_lo = xmm8; + Xbyak::Xmm xsum_hi = xmm9; + Xbyak::Xmm xdst_lo = xmm10; + Xbyak::Xmm xdst_hi = xmm11; + Xbyak::Xmm xsum2_lo = xmm14; + Xbyak::Xmm xsum2_hi = xmm15; + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + for (int i = hoff; i <= Hoff; ++i) + { + for (int j = woff; j <= Woff; ++j) + { + if (i == 0 && j == 0) + { + movups(xdst_lo, ptr[src]); + movups(xdst_hi, ptr[src + 4 * sizeof(float)]); + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + addps(xsum_lo, xdst_lo); + addps(xsum_hi, xdst_hi); + } + else + { + movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]); + movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]); + mulps(xtmp_lo, xtmp_lo); + mulps(xtmp_hi, xtmp_hi); + addps(xsum_lo, xtmp_lo); + addps(xsum_hi, xtmp_hi); + } + } + } + mulps(xsum_lo, xalpha); + mulps(xsum_hi, xalpha); + addps(xsum_lo, xk); + addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk + movaps(xtmp_lo, xsum_lo); + movaps(xtmp_hi, xsum_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xtmp_lo); + movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi); + } + movaps(xsum2_lo, xsum_lo); + movaps(xsum2_hi, xsum_hi); + mulps(xsum2_lo, xsum_lo); + mulps(xsum2_hi, xsum_hi); + mulps(xsum_lo, xsum2_lo); + mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3; + + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha+xk)^0.75 + + movups(xdst_lo, ptr[src]); + movups(xdst_hi, ptr[src + 4 * sizeof(float)]); + divps(xdst_lo, xsum_lo); + divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum + + movups(ptr[dst], xdst_lo); + movups(ptr[dst + 4 * sizeof(float)], xdst_hi); + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); +} + +template +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_within &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 h = r9; + Xbyak::Reg64 w = r10; + Vmm ysum = Vmm(isa == avx2 ? 9 : 9); + Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10); + Vmm ydst = Vmm(isa == avx2 ? 11 : 11); + Vmm ytmp = Vmm(isa == avx2 ? 12 : 12); + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + if (isa == avx2) { + vbroadcastss(yalpha, xalpha); + } else { + shufps(xalpha, xalpha, 0); + } + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + if (isa == avx2) { + vbroadcastss(yk, xk); + } else { + shufps(xk, xk, 0); + } + + int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1; + + for (int i = 0; i < s2; ++i) + { + Label label_t; + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } + else { + within_body_sse42(-i, S2, -j, S2, J.W, pk); + } + } + mov(w, J.W - J.size + 1); + L(label_t); + if (isa == avx2) { + within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-i, S2, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(label_t, T_NEAR); + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-i, S2, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk); + } + } + } + + mov(h, J.H - J.size + 1); + Label lrn_loop_h; + L(lrn_loop_h); + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -j, S2, J.W, pk); + } + } + mov(w, J.W - J.size + 1); + Label lrn_loop_w; + L(lrn_loop_w); + if (isa == avx2) { + within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(lrn_loop_w, T_NEAR); + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-s2, S2, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk); + } + } + dec(h); + cmp(h, 0); + jne(lrn_loop_h, T_NEAR); + + for (int i = J.H - S2; i < J.H; ++i) + { + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -j, S2, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk); + } + } + + mov(w, J.W - J.size + 1); + Label label_b; + L(label_b); + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -s2, S2, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(label_b, T_NEAR); + + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk); + } + } + } + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r9; + Xbyak::Xmm xsrc_prev = xmm2; + Xbyak::Ymm ysrc = ymm3; + Xbyak::Ymm yc = ymm3; + Xbyak::Xmm xsrc_next = xmm4; + Xbyak::Ymm ya = ymm5; + Xbyak::Ymm yb = ymm6; + Xbyak::Ymm yd = ymm7; + Xbyak::Ymm ye = ymm8; + Xbyak::Ymm ysum = ymm9; + Xbyak::Ymm ysum2 = ymm10; + Xbyak::Ymm ydst = ymm11; + Xbyak::Ymm ybase = ymm12; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + sub(t, 64); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + if (J.version == -1) + { + vxorps(xsrc_prev, xsrc_prev, xsrc_prev); + vmovups(ptr[t + 0], xsrc_prev); + } + if (J.version == +1) + { + vxorps(xsrc_next, xsrc_next, xsrc_next); + vmovups(ptr[t + 48], xsrc_next); + } + + mov(hw, J.H*J.W); + + Label lrn_loop; + L(lrn_loop); + + if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + vmovups(ysrc, ptr[src]); + if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); + + if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev); + vmovups(ptr[t + 16], ysrc); + if (J.version != +1) vmovups(ptr[t + 48], xsrc_next); + + vmovups(ya, ptr[t + 16 - 8]); + vmovups(yb, ptr[t + 16 - 4]); + vmovups(yd, ptr[t + 16 + 4]); + vmovups(ye, ptr[t + 16 + 8]); + vmulps(ysum, yc, yc); + vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya + vfmadd231ps(ysum, yb, yb); + vfmadd231ps(ysum, yd, yd); + vfmadd231ps(ysum, ye, ye); + vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk + + vmovaps(ybase, ysum); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ysum2, ysum, ysum); + vmulps(ysum, ysum, ysum2); // ysum = ybase^3; + vsqrtps(ysum, ysum); + vsqrtps(ysum, ysum); // ysum = ybase^0.75 + vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum + vmovups(ptr[dst], ydst); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r9; + + Xbyak::Xmm xsrc_lo = xmm2; + Xbyak::Xmm xsrc_hi = xmm3; + Xbyak::Xmm xc_lo = xmm4; + Xbyak::Xmm xc_hi = xmm5; + Xbyak::Xmm xsum_lo = xc_lo; + Xbyak::Xmm xsum_hi = xc_hi; + Xbyak::Xmm xsrc_prev = xmm6; + Xbyak::Xmm xsrc_next = xmm7; + Xbyak::Xmm xa_lo = xmm8; + Xbyak::Xmm xa_hi = xmm9; + Xbyak::Xmm xb_lo = xmm10; + Xbyak::Xmm xb_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + Xbyak::Xmm xe_lo = xmm14; + Xbyak::Xmm xe_hi = xmm15; + Xbyak::Xmm xbase_lo = xmm14; + Xbyak::Xmm xbase_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + sub(t, 64); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + if (J.version == -1) + { + xorps(xsrc_prev, xsrc_prev); + movups(ptr[t + 0], xsrc_prev); + } + if (J.version == +1) + { + xorps(xsrc_next, xsrc_next); + movups(ptr[t + 48], xsrc_next); + } + + mov(hw, J.H*J.W); + Label lrn_loop; + L(lrn_loop); + + if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + movups(xsrc_lo, ptr[src]); + movups(xsrc_hi, ptr[src + 4 * sizeof(float)]); + if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]); + + if (J.version != -1) movups(ptr[t + 0], xsrc_prev); + movups(ptr[t + 16], xsrc_lo); + movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi); + if (J.version != +1) movups(ptr[t + 48], xsrc_next); + + movups(xa_lo, ptr[t + 16 - 8]); + movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]); + movups(xb_lo, ptr[t + 16 - 4]); + movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]); + movups(xd_lo, ptr[t + 16 + 4]); + movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]); + movups(xe_lo, ptr[t + 16 + 8]); + movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]); + movaps(xc_lo, xsrc_lo); + movaps(xc_hi, xsrc_hi); + mulps(xsum_lo, xc_lo); + mulps(xsum_hi, xc_hi); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + mulps(xsum_lo, xalpha); + mulps(xsum_hi, xalpha); + addps(xsum_lo, xk); + addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk + + movaps(xbase_lo, xsum_lo); + movaps(xbase_hi, xsum_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xsum_lo, xsum_lo); + mulps(xsum_hi, xsum_hi); + mulps(xsum_lo, xbase_lo); + mulps(xsum_hi, xbase_hi); // xsum = xbase^3; + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75 + divps(xsrc_lo, xsum_lo); + divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum + movups(ptr[dst], xsrc_lo); + movups(ptr[dst + 4 * sizeof(float)], xsrc_hi); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0x80000000, 0, 0 + }; + + Xbyak::Reg64 c = r9; + Xbyak::Ymm ya = ymm2; + Xbyak::Ymm yb = ymm3; + Xbyak::Ymm yc = ymm4; + Xbyak::Ymm yd = ymm5; + Xbyak::Ymm ye = ymm6; + Xbyak::Ymm ysum = ymm7; + Xbyak::Ymm ydst = ymm8; + Xbyak::Ymm ybase = ymm9; + Xbyak::Ymm ymask = ymm10; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + vxorps(ysum, ysum, ysum); + + mov(imm_addr64, reinterpret_cast(&mask[0])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(ya, ymask, ptr[src - 8]); + vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + + mov(imm_addr64, reinterpret_cast(&mask[1])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(yb, ymask, ptr[src - 4]); + vfmadd231ps(ysum, yb, yb); + + mov(c, J.C / 8 - 1); + Label lrn_loop; + L(lrn_loop); + + vmovups(yc, ptr[src]); + vmovups(yd, ptr[src + 4]); + vmovups(ye, ptr[src + 8]); + vfmadd231ps(ysum, yc, yc); + vfmadd231ps(ysum, yd, yd); + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + vmovups(ptr[dst], ydst); + + vxorps(ysum, ysum, ysum); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + + vmovups(ya, ptr[src - 8]); + vfmadd231ps(ysum, ya, ya); + vmovups(yb, ptr[src - 4]); + vfmadd231ps(ysum, yb, yb); + + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + vmovups(yc, ptr[src]); + vfmadd231ps(ysum, yc, yc); + + mov(imm_addr64, reinterpret_cast(&mask[2])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(yd, ymask, ptr[src + 4]); + vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + + mov(imm_addr64, reinterpret_cast(&mask[3])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(ye, ymask, ptr[src + 8]); + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + + vmovups(ptr[dst], ydst); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0xffffffff, 0, 0 + }; + + static uint32_t store[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + Xbyak::Reg64 c = r9; + + Xbyak::Xmm xdst_lo = xmm0; + Xbyak::Xmm xdst_hi = xmm1; + Xbyak::Xmm xa_lo = xmm2; + Xbyak::Xmm xa_hi = xmm3; + Xbyak::Xmm xb_lo = xmm2; + Xbyak::Xmm xb_hi = xmm3; + Xbyak::Xmm xc_lo = xmm4; + Xbyak::Xmm xc_hi = xmm5; + Xbyak::Xmm xd_lo = xmm6; + Xbyak::Xmm xd_hi = xmm7; + Xbyak::Xmm xe_lo = xmm8; + Xbyak::Xmm xe_hi = xmm9; + Xbyak::Xmm xsum_lo = xmm10; + Xbyak::Xmm xsum_hi = xmm11; + Xbyak::Xmm xmask_lo = xmm12; + Xbyak::Xmm xmask_hi = xmm13; + Xbyak::Xmm xbase_lo = xmm14; + Xbyak::Xmm xbase_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + mov(store_addr, reinterpret_cast(&store[0])); + and_(store_addr, -15); + movups(ptr[store_addr], xalpha); + movups(ptr[store_addr + 4 * sizeof(float)], xk); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + + mov(imm_addr64, reinterpret_cast(&mask[0])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xa_lo, ptr[src - 8]); + movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); + andps(xa_lo, xmask_lo); + andps(xa_hi, xmask_hi); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(imm_addr64, reinterpret_cast(&mask[1])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xb_lo, ptr[src - 4]); + movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); + andps(xb_lo, xmask_lo); + andps(xb_hi, xmask_hi); + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + + mov(c, J.C / 8 - 1); + Label lrn_loop; + L(lrn_loop); + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + movups(xd_lo, ptr[src + 4]); + movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); + movups(xe_lo, ptr[src + 8]); + movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + movaps(xdst_lo, xsum_lo); + movaps(xdst_hi, xsum_hi); + // xdst <- xsum*xalpha+xk + mulps(xdst_lo, ptr[store_addr]); + mulps(xdst_hi, ptr[store_addr]); + addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + divps(xc_lo, xdst_lo); + divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + movups(ptr[dst], xc_lo); + movups(ptr[dst + 4 * sizeof(float)], xc_hi); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + + movups(xa_lo, ptr[src - 8]); + movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); + movups(xb_lo, ptr[src - 4]); + movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + + mov(imm_addr64, reinterpret_cast(&mask[2])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xd_lo, ptr[src + 4]); + movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); + andps(xd_lo, xmask_lo); + andps(xd_hi, xmask_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(imm_addr64, reinterpret_cast(&mask[3])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xe_lo, ptr[src + 8]); + movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); + andps(xe_lo, xmask_lo); + andps(xe_hi, xmask_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + movups(xdst_lo, xsum_lo); + movups(xdst_hi, xsum_hi); + // xdst <- xsum*xalpha+xk + mulps(xdst_lo, ptr[store_addr]); + mulps(xdst_hi, ptr[store_addr]); + addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + divps(xc_lo, xdst_lo); + divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + + movups(ptr[dst], xc_lo); + movups(ptr[dst + 4 * sizeof(float)], xc_hi); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body( + int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum) {} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body( + int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum) +{ + Xbyak::Ymm ydst = ymm14; + Xbyak::Ymm ybase = ymm15; + + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + { + if (tail != 0) + vmaskmovps(ptr[scratch], ymask, ybase); + else + vmovups(ptr[scratch], ybase); + } + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + + if (tail != 0) + vmaskmovps(ptr[dst], ymask, ydst); + else + vmovups(ptr[dst], ydst); + + + vfnmadd231ps(ysum, ya, ya); + vmovups(ya, yb); + vmovups(yb, yc); + vmovups(yc, yd); + vmovups(yd, ye); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_tail_sse42( + int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) +{} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_tail_sse42( + int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) +{ + Xbyak::Xmm xmm_tmp = xmm10; + movaps(xmm_tmp, xtail_lo); + size_t offset = 0; + + if (tail > 4) { + movups(ptr[reg_dst], xtail_lo); + movaps(xmm_tmp, xtail_hi); + offset += 4 * sizeof(float); + tail -= 4; + } + movss(ptr[reg_dst + offset], xmm_tmp); + for (int i = 1; i < tail; i++) + { + psrldq(xmm_tmp, 4); + movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp); + } +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body_sse42( + int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) +{ + Xbyak::Xmm xdst_lo = xmm0; + Xbyak::Xmm xdst_hi = xmm1; + Xbyak::Xmm xbase_lo = xmm6; + Xbyak::Xmm xbase_hi = xmm7; + Xbyak::Xmm xtmp_lo = xmm8; + Xbyak::Xmm xtmp_hi = xmm9; + Xbyak::Xmm xa_lo = xmm6; + Xbyak::Xmm xa_hi = xmm7; + Xbyak::Xmm xb_lo = xmm8; + Xbyak::Xmm xb_hi = xmm9; + Xbyak::Xmm xc_lo = xmm10; + Xbyak::Xmm xc_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + + // store xe + movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo); + movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi); + + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + // xdst <- xsum*xalpha+xk + movaps(xdst_lo, xsum_lo); + movaps(xdst_hi, xsum_hi); + mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]); + mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]); + addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) + { + if (tail != 0) { + nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi); + } + else { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); + movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); + divps(xtmp_lo, xdst_lo); + divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + movaps(xdst_lo, xtmp_lo); + movaps(xdst_hi, xtmp_hi); + + if (tail != 0) { + nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi); + } + else { + movups(ptr[dst], xdst_lo); + movups(ptr[dst + 4 * sizeof(float)], xdst_hi); + } + + movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]); + movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + subps(xsum_lo, xa_lo); + subps(xsum_hi, xa_hi); + + // xa <- xb + movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]); + movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo); + movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi); + + // xb <- xc + movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); + movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo); + movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi); + + // xc <- xd + movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]); + movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo); + movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi); + + // xd <- xe + movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]); + movaps(xe_hi, ptr[store_addr + 11 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo); + movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32::nchw_body_sse42( + int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0 + }; + Xbyak::Reg64 c = r10; + Xbyak::Ymm ymask = ymm2; + Xbyak::Ymm ye = ymm3; + Xbyak::Ymm ya = ymm4; + Xbyak::Ymm yb = ymm5; + Xbyak::Ymm yc = ymm6; + Xbyak::Ymm yd = ymm7; + Xbyak::Ymm ysum = ymm8; + + this->preamble(); + + if (J.tail != 0) + { + mov(imm_addr64, reinterpret_cast(&mask[7 - J.tail])); + vmovups(ymask, ptr[imm_addr64]); + } + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + vxorps(ya, ya, ya); + vxorps(yb, yb, yb); + if (J.tail != 0) + vmaskmovps(yc, ymask, ptr[src + J.HW * 0]); + else + vmovups(yc, ptr[src + J.HW * 0]); + if (J.tail != 0) + vmaskmovps(yd, ymask, ptr[src + J.HW * 4]); + else + vmovups(yd, ptr[src + J.HW * 4]); + + vxorps(ysum, ysum, ysum); + vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + vfmadd231ps(ysum, yd, yd); + + mov(c, J.C - 2); + Label lrn_loop; + L(lrn_loop); + + if (J.tail != 0) + vmaskmovps(ye, ymask, ptr[src + J.HW * 8]); + else + vmovups(ye, ptr[src + J.HW * 8]); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + vxorps(ye, ye, ye); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0 + }; + + Xbyak::Reg64 c = r10; + + Xbyak::Xmm xmask_lo = xmm2; + Xbyak::Xmm xmask_hi = xmm3; + Xbyak::Xmm xsum_lo = xmm4; + Xbyak::Xmm xsum_hi = xmm5; + Xbyak::Xmm xa_lo = xmm6; + Xbyak::Xmm xa_hi = xmm7; + Xbyak::Xmm xb_lo = xmm8; + Xbyak::Xmm xb_hi = xmm9; + Xbyak::Xmm xc_lo = xmm10; + Xbyak::Xmm xc_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + Xbyak::Xmm xe_lo = xmm14; + Xbyak::Xmm xe_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + sub(rsp, stack_space_needed); + mov(store_addr, rsp); + and_(store_addr, -15); + + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + // put alpha and k into store (free up regs) + movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha); + movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk); + + if (J.tail != 0) + { + mov(imm_addr64, reinterpret_cast(&mask[7 - J.tail])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + } + // init xa, xb + xorps(xa_lo, xa_lo); + xorps(xa_hi, xa_hi); + xorps(xb_lo, xb_lo); + xorps(xb_hi, xb_hi); + + // read xc, xd + if (J.tail != 0) { + movups(xc_lo, ptr[src + J.HW * 0]); + movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); + andps(xc_lo, xmask_lo); + andps(xc_hi, xmask_hi); + } + else { + movups(xc_lo, ptr[src + J.HW * 0]); + movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); + } + if (J.tail != 0) { + movups(xd_lo, ptr[src + J.HW * 4]); + movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); + andps(xd_lo, xmask_lo); + andps(xd_hi, xmask_hi); + } + else { + movups(xd_lo, ptr[src + J.HW * 4]); + movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); + } + + // put xa, xb, xc, xd into store to free-up regs + movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo); + movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi); + movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo); + movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi); + movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo); + movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi); + movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo); + movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(c, J.C - 2); + Label lrn_loop; + L(lrn_loop); + + if (J.tail != 0) { + movups(xe_lo, ptr[src + J.HW * 8]); + movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); + andps(xe_lo, xmask_lo); + andps(xe_hi, xmask_hi); + } + else { + movups(xe_lo, ptr[src + J.HW * 8]); + movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); + } + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + xorps(xe_lo, xe_lo); + xorps(xe_hi, xe_hi); + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + + add(rsp, stack_space_needed); + + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +////////////////////////////////////////////////////////////////////////////// +// backward kernel +template +jit_uni_lrn_bwd_kernel_f32::jit_uni_lrn_bwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , nalphabeta(-2 * A*B) + , use_h_parallelizm(use_h_parallel) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r10; + + Xbyak::Xmm xsrc_prev = xmm1; + Xbyak::Xmm xws_prev = xmm2; + Xbyak::Xmm xdiffdst_prev = xmm3; + Xbyak::Ymm ysrc = ymm4; + Xbyak::Ymm yws = ymm5; + Xbyak::Ymm ydiffdst = ymm6; + Xbyak::Xmm xsrc_next = xmm7; + Xbyak::Xmm xws_next = xmm8; + Xbyak::Xmm xdiffdst_next = xmm9; + Xbyak::Ymm ya = ymm10; + Xbyak::Xmm xa = xmm10; + Xbyak::Ymm yb = ymm11; + Xbyak::Ymm yd = ymm12; + Xbyak::Ymm ye = ymm13; + Xbyak::Ymm ysum = ymm14; + Xbyak::Ymm ydiffsrc = ymm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(diffdst, ptr[this->param1 + 8]); + mov(workspace, ptr[this->param1 + 16]); + mov(diffsrc, ptr[this->param1 + 24]); + + sub(t, 64); + mov(imm_addr64, float2int(this->nalphabeta)); + movq(xnalphabeta, imm_addr64); + vbroadcastss(ynalphabeta, xnalphabeta); + + bool is_single = J.version == 3; + bool is_first = J.version == -1 || J.version == -2; + bool is_last = J.version == +1 || J.version == -2; + + if (is_first || is_single) { + vxorps(xsrc_prev, xsrc_prev, xsrc_prev); + vmovups(ptr[t + 0], xsrc_prev); + } + if (is_last || is_single) { + vxorps(xsrc_next, xsrc_next, xsrc_next); + vmovups(ptr[t + 48], xsrc_next); + } + mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W); + Label lrn_loop; + L(lrn_loop); + { + if (!is_first && !is_single) { + vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]); + vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]); + vmulps(xa, xws_prev, xws_prev); + vmulps(xa, xa, xws_prev); + vsqrtps(xa, xa); + vsqrtps(xa, xa); + vmulps(xa, xa, xws_prev); + vdivps(xsrc_prev, xsrc_prev, xa); + vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev); + } + + vmovups(ysrc, ptr[src]); + vmovups(yws, ptr[workspace]); + vmovups(ydiffdst, ptr[diffdst]); + vmulps(ya, yws, yws); + vmulps(ya, ya, yws); + vsqrtps(ya, ya); + vsqrtps(ya, ya); + vdivps(ydiffsrc, ydiffdst, ya); + vdivps(ysum, ydiffsrc, yws); + vmulps(ysum, ysum, ysrc); + + if (!is_last && !is_single) { + vmovups(xws_next, ptr[workspace + J.H*J.W * 32]); + vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); + vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]); + vmulps(xa, xws_next, xws_next); + vmulps(xa, xa, xws_next); + vsqrtps(xa, xa); + vsqrtps(xa, xa); + vmulps(xa, xa, xws_next); + vdivps(xsrc_next, xsrc_next, xa); + vdivps(xsrc_next, xsrc_next, xws_next); + vmulps(xdiffdst_next, xdiffdst_next, xsrc_next); + } + + if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev); + vmovups(ptr[t + 16], ysum); + if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next); + + vmovups(ya, ptr[t + 16 - 8]); + vmovups(yb, ptr[t + 16 - 4]); + vaddps(ysum, ysum, ya); + vmulps(ysrc, ysrc, ynalphabeta); + vaddps(ysum, ysum, yb); + + vmovups(yd, ptr[t + 16 + 4]); + vmovups(ye, ptr[t + 16 + 8]); + vaddps(ysum, ysum, yd); + vaddps(ysum, ysum, ye); + + vfmadd231ps(ydiffsrc, ysum, ysrc); + + vmovups(ptr[diffsrc], ydiffsrc); + + add(src, 32); + add(diffsrc, 32); + add(diffdst, 32); + add(workspace, 32); + + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast(const_cast( + this->getCode())); +} + +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_fwd_kernel_f32; +template struct jit_uni_lrn_bwd_kernel_f32; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp new file mode 100644 index 0000000000..2b3ed43cd4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_LRN_KERNEL_F32_HPP +#define CPU_JIT_UNI_LRN_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +enum params { VECTOR_LENGTH = 8, MAX_LOCAL_SIZE = 32 }; + +typedef struct { + const float *src; + float *dst, *scratch; +} jit_args_fwd_t; + +typedef struct { + const float *src, *diff_dst, *scratch; + float *diff_src; +} jit_args_bwd_t; + +struct nchw8c_across { + /* version: + * -1: channels 0..7, + * 1: channels C-8 .. C-1, + * 0: other channels + * 3: channels only for this kernel(without prev and next) + */ + int H, W, version; + nchw8c_across(int h, int w, int v) : H(h), W(w), version(v) {} +}; + +struct nchw8c_within { + int H, W, size; + nchw8c_within(int h, int w, int s) : H(h), W(w), size(s) {} +}; + +struct nchw_across { + int C, HW, tail; + nchw_across(int c, int hw, int t) : C(c), HW(hw), tail(t) {} +}; + +struct nhwc_across { + int C; + nhwc_across(int c) : C(c) {} +}; + +template +struct jit_uni_lrn_fwd_kernel_f32 : public jit_generator { + Xbyak::Reg64 src = rax; + Xbyak::Reg64 dst = r8; + Xbyak::Reg64 scratch = rdx; + Xbyak::Reg64 imm_addr64 = rbx; + Xbyak::Reg64 store_addr = rbp; + + Xbyak::Xmm xalpha = xmm0; + Xbyak::Ymm yalpha = ymm0; + Xbyak::Xmm xk = xmm1; + Xbyak::Ymm yk = ymm1; + + float alpha; + float k; + + int stack_space_needed = 11 * 4 * sizeof(float) + 16; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_f32) + + /* cpu specific part */ + using Vmm = typename utils::conditional::type; + + jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_within &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr = nullptr, + size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE); + + void within_body( + int hoff, int Hoff, int woff, int Woff, int stride, + Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, + prop_kind_t pk); + void within_body_sse42( + int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk); + + + void nchw_body(int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum); + void nchw_body_sse42(int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi); + void nchw_tail_sse42(int tail, Xbyak::Reg64 reg_dst, + Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi); + + void operator()(jit_args_fwd_t *arg) { ker(arg); } + void(*ker)(jit_args_fwd_t *); +}; + +template +struct jit_uni_lrn_bwd_kernel_f32 : public jit_generator { + Xbyak::Reg64 src = rax; + Xbyak::Reg64 diffsrc = r8; + Xbyak::Reg64 diffdst = r9; + Xbyak::Reg64 workspace = rdx; + Xbyak::Reg64 imm_addr64 = rsi; + + Xbyak::Xmm xnalphabeta = xmm0; + Xbyak::Ymm ynalphabeta = ymm0; + + float nalphabeta; + + int use_h_parallelizm; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_f32) + + jit_uni_lrn_bwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + + void operator()(jit_args_bwd_t *arg) { ker(arg); } + void(*ker)(jit_args_bwd_t *); +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp new file mode 100644 index 0000000000..bf8e609d23 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp @@ -0,0 +1,699 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "cpu_pooling_pd.hpp" + +#include "jit_uni_pool_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; +using namespace alg_kind; + +#define GET_OFF(field) offsetof(jit_pool_call_s, field) + +template +status_t jit_uni_pool_kernel_f32::init_conf(jit_pool_conf_t &jpp, + const pooling_pd_t *ppd) { + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d( + ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md()); + const memory_desc_wrapper dst_d( + ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md()); + + bool args_ok = true + && mayiuse(isa) + && utils::one_of(pd.alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding); + if (!args_ok) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + const int ndims = src_d.ndims(); + + jpp.ndims = ndims; + jpp.mb = src_d.dims()[0]; + + jpp.c = utils::rnd_up(src_d.dims()[1], simd_w); + if (jpp.c > src_d.padded_dims()[1]) + return status::unimplemented; + + jpp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jpp.ih = src_d.dims()[ndims-2]; + jpp.iw = src_d.dims()[ndims-1]; + jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jpp.oh = dst_d.dims()[ndims-2]; + jpp.ow = dst_d.dims()[ndims-1]; + + jpp.stride_d = (ndims == 5 ) ? pd.strides[0] : 1; + jpp.stride_h = pd.strides[ndims-4]; + jpp.stride_w = pd.strides[ndims-3]; + jpp.kd = (ndims == 5) ? pd.kernel[0] : 1; + jpp.kh = pd.kernel[ndims-4]; + jpp.kw = pd.kernel[ndims-3]; + + jpp.f_pad = (ndims == 5 ) ? pd.padding[0][0] : 0; + jpp.t_pad = pd.padding[0][ndims-4]; + jpp.l_pad = pd.padding[0][ndims-3]; + + jpp.alg = pd.alg_kind; + + jpp.is_training = pd.prop_kind == prop_kind::forward_training; + jpp.is_backward = pd.prop_kind == prop_kind::backward_data; + jpp.ind_dt = ppd->workspace_md() + ? ppd->workspace_md()->data_type : data_type::undef; + + jpp.simple_alg = jpp.is_training + || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d); + + jpp.c_block = simd_w; + + jpp.nb_c = jpp.c / jpp.c_block; + if (jpp.alg == pooling_max) { + jpp.ur_w = isa == avx512_common ? 16 : 4; + if (jpp.is_training) + jpp.ur_w = isa == avx512_common ? 9 : 3; + else if (jpp.is_backward) + jpp.ur_w = isa == avx512_common ? 6 : 3; + } else { + if (jpp.is_backward) + jpp.ur_w = isa == avx512_common ? 12 : 6; + else + jpp.ur_w = isa == avx512_common ? 24 : 12; + } + if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow; + if (jpp.l_pad > jpp.ur_w) return status::unimplemented; + + jpp.ur_w_tail = jpp.ow % jpp.ur_w; + + return status::success; +} + +template +inline void jit_uni_pool_kernel_f32::maybe_recalculate_divisor(int jj, + int ur_w, int pad_l, int pad_r) { + if (jpp.alg == pooling_avg_exclude_padding) { + int kw = jpp.kw; + int stride_w = jpp.stride_w; + + int non_zero_kw = kw; + non_zero_kw -= nstl::max(0, pad_l - jj*stride_w); + non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj)*stride_w); + + if (non_zero_kw != prev_kw) { + mov(tmp_gpr, float2int((float)non_zero_kw)); + movq(xmm_tmp, tmp_gpr); + uni_vbroadcastss(vmm_tmp, xmm_tmp); + uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h); + prev_kw = non_zero_kw; + } + } +} + +template +inline void jit_uni_pool_kernel_f32::avg_step(int ur_w, int pad_l, + int pad_r) { + + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + for (int jj = 0; jj < ur_w; jj++) { + if (jpp.is_backward) { + uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); + maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); + uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); + } else { + uni_vpxor(vreg(jj), vreg(jj), vreg(jj)); + } + } + + if (jpp.simple_alg && jpp.ndims == 5) { + push(reg_input); + push(reg_output); + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + if (jpp.is_backward) { + uni_vmovups(vreg(ur_w+jj), + ptr[aux_reg_input + input_offset]); + uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj)); + uni_vmovups(vmmword[aux_reg_input + input_offset], + vreg(ur_w+jj)); + } else { + uni_vaddps(vreg(jj), vreg(jj), + ptr[aux_reg_input + input_offset]); + } + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + + if (jpp.simple_alg && jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + pop(reg_output); + pop(reg_input); + } + + if (!jpp.is_backward) { + for (int jj = 0; jj < ur_w; jj++) { + maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); + uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); + uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], + vreg(jj)); + } + } +} + +template +inline void jit_uni_pool_kernel_f32::max_step_fwd(int ur_w, int pad_l, + int pad_r) { + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + mov(tmp_gpr, float2int(nstl::numeric_limits::lowest())); + movq(xmm_tmp, tmp_gpr); + uni_vbroadcastss(vmm_tmp, xmm_tmp); + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vreg(jj), vmm_tmp); + if (jpp.is_training) + uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj)); + } + if (jpp.is_training) + { + movq(xmm_tmp, reg_k_shift); + uni_vpbroadcastd(vmm_k_offset, xmm_tmp); + } + + if (jpp.ndims == 5) { + push(reg_input); + push(reg_output); + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + uni_vmovups(vreg(ur_w+jj), ptr[aux_reg_input + input_offset]); + if (isa == sse42) { + movups(vmm_mask, vreg(jj)); + cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os); + blendvps(vreg(jj), vreg(ur_w+jj)); + if (jpp.is_training) + blendvps(vreg(2*ur_w+jj), vmm_k_offset); + } else if (isa == avx) { + vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj), + _cmp_lt_os); + vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj), + vreg(3*ur_w+jj)); + if (jpp.is_training) + vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), + vmm_k_offset, vreg(3*ur_w+jj)); + } else { + vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os); + vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj)); + if (jpp.is_training) + vblendmps(vreg(2*ur_w+jj) | k_store_mask, + vreg(2*ur_w+jj), vmm_k_offset); + } + } + if (jpp.is_training) { + if (isa == avx && !mayiuse(avx2)) { + avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); + } + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + + if (jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + if (jpp.is_training) { + mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + if (isa == avx && !mayiuse(avx2)) { + Xmm t(vmm_mask.getIdx()); + avx_vpadd1(vmm_k_offset, xmm_tmp, t); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); + } + } + + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + pop(reg_output); + pop(reg_input); + } + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], vreg(jj)); + if (jpp.is_training) { + const size_t step_index + = jj * c_block * types::data_type_size(jpp.ind_dt); + + auto x = xreg(2 * ur_w + jj); + if (jpp.ind_dt == data_type::u8) { + if (isa == sse42) { + for (int i = 0; i < 4; ++i) + pextrb(ptr[reg_index + step_index + i], x, 4*i); + } else if (isa == avx) { + auto y = yreg(2 * ur_w + jj); + if (jj == 0) { + movd(xmm_tmp, reg_shuf_mask); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + } + if (mayiuse(avx2)) { + vpshufb(y, y, vmm_tmp); + movd(ptr[reg_index + step_index], x); + vperm2i128(y, y, y, 0x1u); + movd(ptr[reg_index + step_index + 4], x); + } else { + Xmm t(vmm_mask.getIdx()); + vextractf128(t, y, 0); + vpshufb(t, t, xmm_tmp); + movd(ptr[reg_index + step_index], t); + vextractf128(t, y, 1); + vpshufb(t, t, xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0] + movd(ptr[reg_index + step_index + 4], t); + } + } else { + auto v = vreg(2 * ur_w + jj); + vpmovusdb(x, v); + vmovups(ptr[reg_index + step_index], v | k_index_mask); + } + } else { + uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj)); + } + } + } +} + +template +inline void jit_uni_pool_kernel_f32::max_step_bwd(int ur_w, int pad_l, + int pad_r) { + + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); + + const size_t step_index + = jj * c_block * types::data_type_size(jpp.ind_dt); + if (jpp.ind_dt == data_type::u8) { + if (isa == sse42) { + movd(xreg(ur_w+jj), ptr[reg_index + step_index]); + pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } else if (isa == avx) { + movq(xreg(ur_w+jj), ptr[reg_index + step_index]); + if (!mayiuse(avx2)) { + avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp); + } else { + vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } + } else { + vmovups(vreg(ur_w+jj) | k_index_mask, + ptr[reg_index + step_index]); + vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } + } else { + uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]); + } + } + movq(xmm_tmp, reg_k_shift); + uni_vpbroadcastd(vmm_k_offset, xmm_tmp); + + if (jpp.simple_alg && jpp.ndims == 5) { + push(reg_input); + push(reg_output); + if (isa == sse42) { + // Save rdi since it is used in maskmovdqu + assert(dst_ptr == rdi); + push(dst_ptr); + } + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + uni_vmovups(vreg(2*ur_w+jj), ptr[aux_reg_input + input_offset]); + if (isa == sse42) { + mov(dst_ptr, aux_reg_input); + add(dst_ptr, input_offset); + + movups(vreg(3*ur_w+jj), vreg(ur_w+jj)); + pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset); + addps(vreg(2*ur_w+jj), vreg(jj)); + maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj)); + } else if (isa == avx) { + if (mayiuse(avx2)) { + vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset); + } else { + avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset, xmm_tmp); + } + vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj)); + vmaskmovps(vmmword[aux_reg_input + input_offset], + vreg(3*ur_w+jj), vreg(2*ur_w+jj)); + } else { + vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset); + vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj)); + vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp); + vmovups(vmmword[aux_reg_input + + sizeof(float)*aux_input_offset], vreg(2*ur_w+jj)); + } + } + if (isa == avx && !mayiuse(avx2)) { + avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + if (jpp.simple_alg && jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + + mov(tmp_gpr, reg_kd_pad_shift); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + if (isa == avx && !mayiuse(avx2)) { + Xmm t(vmm_mask.getIdx()); + avx_vpadd1(vmm_k_offset, vmm_tmp, t); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); + } + + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + if (isa == sse42) { + // Save rdi since it is used in maskmovdqu + assert(dst_ptr == rdi); + pop(dst_ptr); + } + pop(reg_output); + pop(reg_input); + } +} + +template +void jit_uni_pool_kernel_f32::maybe_zero_diff_src() { + assert(jpp.c_block * sizeof(float) % cpu_isa_traits::vlen == 0); + Label l_skip, l_zero; + + auto reg_oh = tmp_gpr; + mov(reg_oh, ptr[reg_param + GET_OFF(oh)]); + cmp(reg_oh, 0); + jz(l_skip, T_NEAR); + + if (jpp.ndims == 5) { + mov(zero_size, ptr[reg_param + GET_OFF(oh)]); + mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * sizeof(float)); + imul(zero_size, tmp_gpr); + } + + auto vzero = vmm_tmp; + uni_vpxor(vzero, vzero, vzero); + + auto reg_off = tmp_gpr; + xor_(reg_off, reg_off); + + L(l_zero); + { + const int dim = jpp.iw * jpp.c_block * sizeof(float); + for (int i = 0; i < dim; i += cpu_isa_traits::vlen) + uni_vmovups(ptr[reg_input + reg_off + i], vzero); + add(reg_off, dim); + if (jpp.ndims == 5) cmp(reg_off, zero_size); + else cmp(reg_off, jpp.ih * dim); + jl(l_zero, T_NEAR); + } + + L(l_skip); +} + +template +void jit_uni_pool_kernel_f32::generate() { + + this->preamble(); + + int ow = jpp.ow; + int iw = jpp.iw; + int kw = jpp.kw; + int kh = jpp.kh; + int ur_w = jpp.ur_w; + int c_block = jpp.c_block; + int stride_w = jpp.stride_w; + int l_pad = jpp.l_pad; + int ur_w_tail = jpp.ur_w_tail; + + int n_oi = ow / ur_w; + + prev_kw = 0; + + int vlen = cpu_isa_traits::vlen; + +#if defined(_WIN32) + // Always mimic the Unix ABI (see the note about maskmovdqu in the header + // file). + xor_(rdi, rcx); + xor_(rcx, rdi); + xor_(rdi, rcx); +#endif + + mov(reg_input, ptr[reg_param + GET_OFF(src)]); + mov(reg_output, ptr[reg_param + GET_OFF(dst)]); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + mov(reg_index, ptr[reg_param + GET_OFF(indices)]); + mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]); + mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]); + mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]); + + if (jpp.is_backward) + maybe_zero_diff_src(); + + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) { + mov(tmp_gpr, 1); + movq(xmm_one, tmp_gpr); + uni_vpbroadcastd(vmm_one, xmm_one); + + if (isa == avx) { + mov(reg_shuf_mask, 0x0c080400); + } else if (isa >= avx512_common) { + mov(tmp_gpr.cvt32(), 0x000f); + kmovw(k_index_mask, tmp_gpr.cvt32()); + } + } + + int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1)); + int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (jpp.alg == pooling_avg_exclude_padding) { + movq(xmm_ker_area_h, reg_ker_area_h); + uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h); + } + + if (jpp.alg == pooling_avg_include_padding) { + mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd))); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + } + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) { + step(ur_w, l_pad, r_pad1); + } else { + step(ur_w, l_pad, 0); + } + + if (isa == sse42) { + if (n_oi < 0 && r_pad1 > 0) { + step_high_half(ur_w, l_pad, r_pad1); + } else { + step_high_half(ur_w, l_pad, 0); + } + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*(ur_w*stride_w-l_pad)*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + } + + xor_(oi_iter, oi_iter); + if (n_oi > 0) { + Label ow_loop; + L(ow_loop); { + step(ur_w, 0, 0); + + if (isa == sse42) { + step_high_half(ur_w, 0, 0); + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + } + + if (r_pad1 > 0 && n_oi >= 0) { + step(ur_w, 0, r_pad1); + + if (isa == sse42) { + step_high_half(ur_w, 0, r_pad1); + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + } + + if (ur_w_tail != 0) { + step(ur_w_tail, 0, r_pad); + + if (isa == sse42) { + step_high_half(ur_w_tail, 0, r_pad); + } + } + + this->postamble(); +} + +template struct jit_uni_pool_kernel_f32; +template struct jit_uni_pool_kernel_f32; // implements both and +template struct jit_uni_pool_kernel_f32; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp new file mode 100644 index 0000000000..992b526587 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp @@ -0,0 +1,192 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_POOL_KERNEL_F32_HPP +#define JIT_UNI_POOL_KERNEL_F32_HPP + +#include + +#include "c_types_map.hpp" +#include "pooling_pd.hpp" +#include "type_helpers.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +template +struct jit_uni_pool_kernel_f32: public jit_generator { + jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp) + { + this->generate(); + jit_ker = (decltype(jit_ker))this->getCode(); + } + + jit_pool_conf_t jpp; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32) + + void operator()(jit_pool_call_s *arg) { jit_ker(arg); } + static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd); + +private: + using Vmm = typename utils::conditional3::type; + Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); } + Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); } + Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); } + + const AddressFrame &vmmword = (isa == sse42) ? xword : + (isa == avx) ? yword : zword; + + Xmm vmm_mask = Xmm(0); + Xmm xmm_ker_area_h = Xmm(2); + Xmm xmm_one = Xmm(2); + Xmm xmm_tmp = Xmm(3); + + Vmm vmm_ker_area_h = Vmm(2); + Vmm vmm_one = Vmm(2); + Vmm vmm_tmp = Vmm(3); + + Vmm vmm_k_offset = Vmm(1); + + Opmask k_index_mask = Opmask(6); + Opmask k_store_mask = Opmask(7); + + // Here be some (tame) dragons. This kernel does not follow the regular + // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu + // instruction which has its destination hardcoded in rdi. Therefore: + // - all registers are hardcoded + // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI + // + // While this is only required by the backward pass, the quirk above + // is applied to the forward pass as well to keep things simpler. + + using reg64_t = const Xbyak::Reg64; + reg64_t reg_param = rdi; // Always mimic the Unix ABI + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t reg_index = r10; + reg64_t reg_output = r12; + reg64_t reg_kd_pad_shift = r13; + reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu + + reg64_t kj = r14; + reg64_t oi_iter = r15; + reg64_t reg_kh = rax; + reg64_t reg_k_shift = rbx; + reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above + reg64_t reg_ker_area_h = rdx; + + reg64_t zero_size = r15; + reg64_t ki = r12; + reg64_t aux_reg_input_d = r8; + + Xbyak::Reg32 reg_shuf_mask = esi; + + int prev_kw; + void (*jit_ker)(jit_pool_call_s *); + + void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r); + void avg_step(int ur_w, int pad_l, int pad_r); + void max_step_fwd(int ur_w, int pad_l, int pad_r); + void max_step_bwd(int ur_w, int pad_l, int pad_r); + + void maybe_zero_diff_src(); + + void step(int ur_w, int pad_l, int pad_r) { + if (jpp.alg == alg_kind::pooling_max) { + if(jpp.is_backward) + max_step_bwd(ur_w, pad_l, pad_r); + else + max_step_fwd(ur_w, pad_l, pad_r); + } + else + avg_step(ur_w, pad_l, pad_r); + } + + void step_high_half(int ur_w, int pad_l, int pad_r) { + add(reg_input, sizeof(float) * 4); + add(reg_output, sizeof(float) * 4); + if (jpp.alg == alg_kind::pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, types::data_type_size(jpp.ind_dt) * 4); + + step(ur_w, pad_l, pad_r); + } + + void generate(); + + void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { + assert(y0.getIdx() != x1.getIdx()); + vextractf128(xtmp, y0, 0); + vpaddd(xtmp, xtmp, x1); + vinsertf128(y0, y0, xtmp, 0); + vextractf128(xtmp, y0, 1); + vpaddd(xtmp, xtmp, x1); + vinsertf128(y0, y0, xtmp, 1); + } + + void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) { + assert(false /*function should not be used*/); + paddd(x0, x1); + } + + void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { + Xmm x0(y0.getIdx()); + pshufd(xmm_tmp, x1, 1); + pmovzxbd(x0, x1); + pmovzxbd(xmm_tmp, xmm_tmp); + vinsertf128(y0, y0, xmm_tmp, 1); + } + + void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) { + assert(false /*function should not be used*/); + pmovzxbd(x0, x1); + } + + void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) { + assert(y0.getIdx() != y1.getIdx()); + assert(y0.getIdx() != y2.getIdx()); + Xmm x0(y0.getIdx()); + Xmm x2(y2.getIdx()); + vextractf128(x0, y1, 1); + vextractf128(xtmp, y2, 1); + pcmpeqd(xtmp, x0); + vextractf128(x0, y1, 0); + pcmpeqd(x0, x2); + vinsertf128(y0, y0, xtmp, 1); + } + + void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) { + assert(false /*function should not be used*/); + pcmpeqd(x0, x1); + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp new file mode 100644 index 0000000000..afbcf996d8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "nstl.hpp" + +#include "jit_uni_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, + data_t *dst, char *indices) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int oh) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &src[src_d.blk_off(n, b_c, ih)]; + arg.dst = &dst[dst_d.blk_off(n, b_c, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = oh == 0; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, jpp.oh, + [&](int n, int b_c, int oh) { + ker(n, b_c, oh); + }); +} + +template +void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, + data_t *dst, char *indices) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, + int d_b_overflow) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &src[src_d.blk_off(n, b_c, id, ih)]; + arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = (oh + od == 0); + arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh; + arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - + nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - + nstl::max(0, jpp.f_pad - od*jpp.stride_d)); + + + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + [&](int n, int b_c, int od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); + const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad) + -jpp.id; + const int id = nstl::max(ik - jpp.f_pad, 0); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow); + } + }); +} + +template +void jit_uni_pooling_bwd_t::execute_backward(const data_t *diff_dst, + const char *indices, data_t *diff_src) const { + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int oh) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)]; + arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = (oh == 0); + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); + + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, oh); + } + }); +} + +template +void jit_uni_pooling_bwd_t::execute_backward_3d(const data_t *diff_dst, + const char *indices, data_t *diff_src) const { + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, + int d_b_overflow, int zero_size, int kd) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)]; + arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = zero_size; + arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh + + kd * jpp.kw * jpp.kh; + arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - + nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - + nstl::max(0, jpp.f_pad - od*jpp.stride_d)); + + (*kernel_)(&arg); + }; + + if (jpp.simple_alg) { + + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + [&](int n, int b_c, int od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad - ik); + const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd + - jpp.f_pad) - jpp.id; + const int id = nstl::max(ik - jpp.f_pad, 0); + int zero_s = jpp.stride_d - d_t_overflow - (nstl::max( + jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, + (oh == 0) ? zero_s : 0, 0); + } + }); + } else { + ptrdiff_t nelems = (ptrdiff_t)jpp.mb * (ptrdiff_t)jpp.c + * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw; + + parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; }); + + for (int kd = 0; kd < jpp.kd; ++kd) { + parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { + for (int od = 0; od < jpp.od; ++od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); + const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd + - jpp.f_pad) - jpp.id; + if (kd >= jpp.kd - d_t_overflow - d_b_overflow) + continue; + const int id = nstl::max(ik - jpp.f_pad, 0); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, + 0, kd); + } + } + }); + } + } +} + + +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; +template struct jit_uni_pooling_fwd_t; +template struct jit_uni_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp new file mode 100644 index 0000000000..57bebacdee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp @@ -0,0 +1,182 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_POOLING_HPP +#define CPU_JIT_UNI_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_uni_pool_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct jit_uni_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_pooling_fwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, + src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag()) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag()); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return jit_uni_pool_kernel_f32::init_conf(jpp_, this); + } + + format_tag_t desired_fmt_tag() { + using namespace format_tag; + return ndims() == 4 + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + } + + jit_pool_conf_t jpp_; + }; + + jit_uni_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_pool_kernel_f32(pd()->jpp_); } + + ~jit_uni_pooling_fwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE); + + if (pd()->ndims() == 5) + execute_forward_3d(src, dst, ws); + else + execute_forward(src, dst, ws); + + return status::success; + } + +private: + void execute_forward(const data_t *src, data_t *dst, char *indices) const; + void execute_forward_3d(const data_t *src, data_t *dst, + char *indices) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_pool_kernel_f32 *kernel_; +}; + +template +struct jit_uni_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_pooling_bwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, + diff_src_md()->data_type, + diff_dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag()) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag()); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return jit_uni_pool_kernel_f32::init_conf(jpp_, this); + } + + format_tag_t desired_fmt_tag() { + using namespace format_tag; + return ndims() + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + } + + jit_pool_conf_t jpp_; + }; + + jit_uni_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_pool_kernel_f32(pd()->jpp_); } + + ~jit_uni_pooling_bwd_t() { delete kernel_; } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + if (pd()->ndims() == 5) + execute_backward_3d(diff_dst, ws, diff_src); + else + execute_backward(diff_dst, ws, diff_src); + + return status::success; + } + +private: + void execute_backward(const data_t *diff_dst, const char *indices, + data_t *diff_src) const; + void execute_backward_3d(const data_t *diff_dst, const char *indices, + data_t *diff_src) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_pool_kernel_f32 *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp new file mode 100644 index 0000000000..98796503b7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp @@ -0,0 +1,1006 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "mkldnn_debug.h" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" +#include "jit_uni_reorder.hpp" + +#include "jit_generator.hpp" + +// #define TR_DEBUG +#if defined(TR_DEBUG) +#define DEBUg(...) do { __VA_ARGS__ } while (0) +#else +#define DEBUg(...) +#endif +#define DEBUG(...) DEBUg(__VA_ARGS__) + +#ifdef _WIN32 +/* seems like s_addr is a reserved macro on Windows */ +#undef s_addr +#endif + +using namespace Xbyak; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +/** Minimal reasonable/desirable kernel size. + * The constant might be used to determine how a problem should be split + * between kernel and threading driver. */ +const size_t ker_prb_size_min = 64; + +/* kernel */ +struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) + + enum { + len_unroll_max = 256, + ndims_jit_loop_max = 3, + }; + + struct simple_impl_desc_t { + int ndims_full_unroll; + int len_last_dim_unroll; + int len_unroll; + }; + + static bool simple_impl_desc_init(const prb_t &prb, + simple_impl_desc_t *desc) { + const int ndims = prb.ndims; + + int ndims_full_unroll = 0; + int len_last_dim_unroll = 1; + int len_unroll = 1; + + for (int d = 0; d < ndims; ++d) { + auto &node = prb.nodes[d]; + if (len_unroll * node.n <= len_unroll_max) { + ndims_full_unroll++; + len_unroll *= node.n; + } else { + len_last_dim_unroll = len_unroll_max / len_unroll; + while (node.n % len_last_dim_unroll) + --len_last_dim_unroll; + len_unroll *= len_last_dim_unroll; + break; + } + } + + if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) + return false; + + if (desc) { + desc->ndims_full_unroll = ndims_full_unroll; + desc->len_last_dim_unroll = len_last_dim_unroll; + desc->len_unroll = len_unroll; + } + + return true; + } + + static bool applicable(const prb_t &p) { + using namespace data_type; + + bool ok = true + && p.ndims > 0 + && utils::one_of(p.itype, f32, s32, s8, u8) + && utils::one_of(p.otype, f32, s32, s8, u8) + && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ + && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ + && simple_impl_desc_init(p, nullptr) + && mayiuse(sse42) + && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype), + mayiuse(avx)); + if (!ok) return false; + + const ptrdiff_t max_stride = (1LL<<31) - 1; + for (int d = 0; d < p.ndims; ++d) { + const ptrdiff_t cms = max_stride / p.nodes[d].n; + bool strides_ok = true + && p.nodes[d].is < cms / (int)data_type_size(p.itype) + && p.nodes[d].os < cms / (int)data_type_size(p.otype); + if (!strides_ok) return false; + } + + return true; + } + + int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; } + int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; } + int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; } + int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; } + + Address i_addr(int i_off) + { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; } + + Address o_addr(int o_off) + { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; } + + Address s_addr(int s_off) + { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; } + + void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, + int &i_off, int &o_off, int &s_off, int step_size = 1) { + i_off = prev_i_off; + o_off = prev_o_off; + s_off = prev_s_off; + + if (off == 0) return; + + int start_dim = 0, dims_prod = 1; + for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) + dims_prod *= n(start_dim); + assert(start_dim < prb_.ndims); + off /= step_size; + + for (int d = start_dim; d < prb_.ndims; ++d) { + i_off += is(d); + o_off += os(d); + s_off += ss(d); + + if (off % n(d)) break; + + i_off += - n(d) * is(d); + o_off += - n(d) * os(d); + s_off += - n(d) * ss(d); + off /= n(d); + + if (off == 0) break; /* FIXME: is it really required? */ + } + } + + void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, + int step_size = 1) { + int dummy = 0; + step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy, + step_size); + } + + void tr8x8_avx2(int i_off, int o_off) { + for (int i = 0; i < 8; i++) + vmovups(Ymm(i), i_addr(i_off + i * 8)); + + for (int i = 0; i < 8 / 2; i++) { + vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1)); + vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1)); + } + + const unsigned int lfloat = 0x44; + const unsigned int ufloat = 0xee; + for (int i = 0; i < 8 / 2; i++) { + int j = i % 2 == 0 ? 8 + i : i - 1; + vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat); + vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat); + } + + const unsigned int lquad = 0x20; + for (int i = 0; i < 8 / 2; i++) + vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad); + + const unsigned int uquad = 0x31; + for (int i = 8 / 2; i < 8; i++) + vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad); + + for (int i = 0; i < 8; i++) + vmovups(o_addr(o_off + i * 8), Ymm(i)); + } + + bool process_unroll_tr8x8(int len) { + bool can_do = true + && mayiuse(avx2) + && prb_.ndims >= 2 + && utils::everyone_is(4, itype_sz, otype_sz) + && utils::everyone_is(8, n(0), n(1)) + && utils::everyone_is(1, os(0), is(1)) + && utils::everyone_is(8, os(1), is(0)) + && prb_.scale_type == scale_type_t::NONE + && prb_.beta == 0.f; + if (!can_do) return false; + + const int step_size = n(0) * n(1); + int i_off = 0, o_off = 0; + for (int off = 0; off < len; off += step_size) { + step(off, i_off, o_off, i_off, o_off, step_size); + tr8x8_avx2(i_off, o_off); + } + + return true; + } + + template + bool process_direct_copy(int len) { + using namespace data_type; + + using Vmm = typename cpu_isa_traits::Vmm; + const int simd_w = cpu_isa_traits::vlen / itype_sz; + + bool can_do = true + && mayiuse(isa) + && utils::everyone_is(1, os(0), is(0)) + && (false + || prb_.itype == prb_.otype + || (prb_.itype == s32 && prb_.otype == f32) + || (prb_.itype == f32 && prb_.otype == s32) + ) + && len % simd_w == 0 + && n(0) % len == 0 + && prb_.scale_type == scale_type_t::NONE + && prb_.beta == 0.f; + if (!can_do) return false; + + for (int off = 0; off < len;) { + const int unroll = nstl::min(16, (len - off) / simd_w); + + for (int ur = 0; ur < unroll; ++ur) + uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w)); + + if (prb_.itype != prb_.otype) { + for (int ur = 0; ur < unroll; ++ur) { + if (prb_.itype == s32 && prb_.otype == f32) + uni_vcvtdq2ps(Vmm(ur), Vmm(ur)); + else if (prb_.itype == f32 && prb_.otype == s32) + uni_vcvtps2dq(Vmm(ur), Vmm(ur)); + else assert(!"unreachable"); + } + } + + for (int ur = 0; ur < unroll; ++ur) + uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur)); + + off += unroll * simd_w; + } + + return true; + } + + void process_unroll_generic_step(int reg_unroll, const int *i_off, + const int *o_off, const int *s_off) { + using namespace data_type; + + auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) { + Xmm dst_pure = Xmm(dst.getIdx()); + switch (idt) { + case f32: + if (src.isMEM() || src.getIdx() != dst.getIdx()) + vmovups(dst, src); + break; + case s32: vcvtdq2ps(dst, src); break; + case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; + case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; + default: assert(!"unreachable"); + } + }; + + auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) { + switch (odt) { + case s32: + if (idt == f32) vcvtps2dq(xmm, xmm); + else if (idt == s8) vpmovsxbd(xmm, xmm); + else if (idt == u8) vpmovzxbd(xmm, xmm); + break; + case s8: + if (idt == f32) vcvtps2dq(xmm, xmm); + if (idt == f32 || idt == s32) { + if (mayiuse(avx512_core)) { + vpmovsdb(xmm, xmm); + } else { + vpackssdw(xmm, xmm, xmm_zero); + vpacksswb(xmm, xmm, xmm_zero); + } + } + if (idt == u8) vpminub(xmm, xmm, xmm_4x127b); + break; + case u8: + if (idt == f32) vcvtps2dq(xmm, xmm); + if (idt == f32 || idt == s32) { + if (mayiuse(avx512_core)) { + vpmaxsd(xmm, xmm, xmm_zero); + vpmovusdb(xmm, xmm); + } else { + vpackssdw(xmm, xmm, xmm_zero); + vpackuswb(xmm, xmm, xmm_zero); + } + } + if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero); + break; + default: assert(!"unreachable"); + } + }; + + auto load = [=](const Xmm &xmm, const Address &addr, int size) { + switch (size) { + case 16: movups(xmm, addr); break; + case 4: movss(xmm, addr); break; + case 1: pinsrb(xmm, addr, 0x0); break; + default: assert(!"unreachable"); + } + }; + + auto store = [=](const Address &addr, const Xmm &xmm, int size) { + switch (size) { + case 16: movups(addr, xmm); break; + case 4: movss(addr, xmm); break; + case 1: pextrb(addr, xmm, 0x0); break; + default: assert(!"unreachable"); + } + }; + + /* check whether loading 4 values at once is possible */ + bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (i_off[ur] != i_off[ur - 1] + 1) + can_load_xmm = false; + const int load_step = can_load_xmm ? 4 : 1; + + /* check whether storing 4 values at once is possible */ + bool can_store_xmm = reg_unroll % 4 == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (o_off[ur] != o_off[ur - 1] + 1) + can_store_xmm = false; + const int ur_step = can_store_xmm ? 4 : 1; + + const bool interim_f32 = false + || utils::one_of(f32, prb_.itype, prb_.otype) + || prb_.scale_type != scale_type_t::NONE + || prb_.beta != 0.f; + + if (!can_load_xmm && can_store_xmm) { + assert(ur_step == 4); + /* load with stride */ + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + for (int r = 0; r < ur_step; ++r) { + if (itype_sz == 4) + pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r); + else + pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r); + } + } + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz); + } + + /* xmm[:] <-- (f32)xmm[:] */ + if (interim_f32) { + const int cvt_step = nstl::max(load_step, ur_step); + for (int ur = 0; ur < reg_unroll; ur += cvt_step) + cvt2ps(Xmm(ur), Xmm(ur), prb_.itype); + } + + if (can_load_xmm && !can_store_xmm) { + const bool fast_return = true // transposition on the fly + && prb_.scale_type != scale_type_t::MANY + && prb_.beta == 0.f; + if (fast_return) { + for (int ur = 0; ur < reg_unroll; ur += load_step) { + if (prb_.scale_type == scale_type_t::COMMON) + mulps(Xmm(ur), xmm_scale); + if (prb_.otype != f32) + cvt2int(Xmm(ur), prb_.otype, + interim_f32 ? f32 : prb_.itype); + for (int r = 0; r < load_step; ++r) { + if (otype_sz == 4) + pextrd(o_addr(o_off[ur + r]), Xmm(ur), r); + else + pextrb(o_addr(o_off[ur + r]), Xmm(ur), r); + } + } + return; + } + + /* scatter elements of xmm into 4 xmms */ + if (itype_sz == 4 || interim_f32) { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r); + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r); + } + } + + /* scale and beta processing */ + if (can_store_xmm) { + /* xmm <-- scale * xmm[:] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + mulps(Xmm(ur), xmm_scale); + } else if (prb_.scale_type == scale_type_t::MANY) { + enum class scale_load_type_t { bcast, load, gather }; + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + scale_load_type_t scale_load_type = + scale_load_type_t::bcast; // the best case + + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 0) + scale_load_type = scale_load_type_t::load; + + if (scale_load_type == scale_load_type_t::bcast) { + movss(xmm_scale, s_addr(s_off[ur])); + shufps(xmm_scale, xmm_scale, 0x0); + mulps(Xmm(ur), xmm_scale); + continue; + } + + // bcast doesn't work, the next try -- load + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 1) + scale_load_type = scale_load_type_t::gather; + + if (scale_load_type == scale_load_type_t::load) { + movups(xmm_scale, s_addr(s_off[ur])); + mulps(Xmm(ur), xmm_scale); + continue; + } + + // load doesn't work as well + // so gather the scale factors one by one + for (int r = ur; r < ur + ur_step; ++r) + pinsrd(xmm_scale, s_addr(s_off[r]), r - ur); + mulps(Xmm(ur), xmm_scale); + } + } + + /* dst <-- beta * dst + xmm[:] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype == f32) { + /* non VEX instructions do not support unaligned + * memory for instructions other than movups. */ + if (mayiuse(avx)) { + vaddps(Xmm(ur), o_addr(o_off[ur])); + } else { + /* register xmm(1) is unused */ + movups(Xmm(1), o_addr(o_off[ur])); + addps(Xmm(ur), Xmm(1)); + } + } else { + cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype); + vaddps(Xmm(ur), Xmm(1)); + } + } + } + } else { + /* xmm[0] <-- scale * xmm[0] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + mulss(Xmm(ur), xmm_scale); + } else if (prb_.scale_type == scale_type_t::MANY) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + mulss(Xmm(ur), s_addr(s_off[ur])); + } + } + + /* dst <-- beta * dst + xmm[0] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype == f32) { + addss(Xmm(ur), o_addr(o_off[ur])); + } else { + if (prb_.otype == s32) { + vmovss(xmm_tmp, o_addr(o_off[ur])); + } else if (utils::one_of(prb_.otype, s8, u8)) { + pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0); + } else { + assert(!"unsupported o_type"); + } + cvt2ps(xmm_tmp, xmm_tmp, prb_.otype); + addps(Xmm(ur), xmm_tmp); + } + } + } + } + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype != f32) + cvt2int(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype); + store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz); + } + } + + void process_unroll_generic(int len) { + const int blk = 8; + + int i_off[2 * blk] = {0}; + int o_off[2 * blk] = {0}; + int s_off[2 * blk] = {0}; + + int curr = 0; // will switch between 0 and 1 + + for (int off = 0; off < len; off += blk) { + const int reg_unroll = nstl::min(off + blk, len) - off; + + /* compute offsets */ + for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { + const int ur_c = curr * blk + ur; + const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur + step(off + ur, + i_off[ur_p], o_off[ur_p], s_off[ur_p], + i_off[ur_c], o_off[ur_c], s_off[ur_c]); + } + + process_unroll_generic_step(reg_unroll, i_off + curr * blk, + o_off + curr * blk, s_off + curr * blk); + + curr = 1 - curr; + } + } + + void loop_begin(Label &l, Reg64 reg_cnt, int len) { + mov(reg_cnt, len); + L(l); + } + + void loop_end(Label &l, Reg64 reg_cnt, int len, + int i_step, int o_step, int s_step) { + add(reg_off_in, i_step * itype_sz); + add(reg_off_out, o_step * otype_sz); + if (prb_.scale_type == scale_type_t::MANY) + add(reg_off_scale, s_step * stype_sz); + dec(reg_cnt); + jnz(l); + + sub(reg_off_in, len * i_step * itype_sz); + sub(reg_off_out, len * o_step * otype_sz); + if (prb_.scale_type == scale_type_t::MANY) + sub(reg_off_scale, len * s_step * stype_sz); + } + + bool simple_impl() { + simple_impl_desc_t d; + if (!simple_impl_desc_init(prb_, &d)) return false; + + const int nfu = d.ndims_full_unroll; + const int ldu = d.len_last_dim_unroll; + const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; + assert(n_jit_loops <= ndims_jit_loop_max); + + xor_(reg_off_in, reg_off_in); + xor_(reg_off_out, reg_off_out); + if (prb_.scale_type == scale_type_t::MANY) + xor_(reg_off_scale, reg_off_scale); + + Label l_loop[3]; + Reg64 reg_cnt[3] = {r15, r14, r13}; + + if (n_jit_loops > 2) + loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2)); + + if (n_jit_loops > 1) + loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1)); + + if (n_jit_loops > 0) + loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); + + const bool optimized = false + || process_direct_copy(d.len_unroll) + || process_direct_copy(d.len_unroll) + || process_unroll_tr8x8(d.len_unroll); + if (!optimized) + process_unroll_generic(d.len_unroll); + + if (n_jit_loops > 0) + loop_end(l_loop[0], reg_cnt[0], + n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu, + ss(nfu + 0) * ldu); + + if (n_jit_loops > 1) + loop_end(l_loop[1], reg_cnt[1], + n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1)); + + if (n_jit_loops > 2) + loop_end(l_loop[2], reg_cnt[2], + n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2)); + + return true; + } + + void impl() { + if (simple_impl()) return; + assert(!"no implementation available"); + } + + jit_uni_reorder_kernel_f32(const desc_t &desc) + : kernel_t(desc), jit_generator() { + itype_sz = data_type_size(prb_.itype); + otype_sz = data_type_size(prb_.otype); + stype_sz = sizeof(float); + + preamble(); +# define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)] + if (prb_.scale_type == scale_type_t::COMMON) { + auto reg_ptr_scale_tmp = reg_ptr_in; + mov(reg_ptr_scale_tmp, PARAM(scale)); + movups(xmm_scale, ptr[reg_ptr_scale_tmp]); + } else if (prb_.scale_type == scale_type_t::MANY) { + mov(reg_ptr_scale, PARAM(scale)); + } + mov(reg_ptr_in, PARAM(in)); + mov(reg_ptr_out, PARAM(out)); +# undef PARAM + + if (mayiuse(avx)) { + vxorps(xmm_zero, xmm_zero, xmm_zero); + + if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { + mov(reg_tmp.cvt32(), 0x7f7f7f7f); + movd(xmm_4x127b, reg_tmp.cvt32()); + } + } + + impl(); + postamble(); + ker_ = (void (*)(const call_param_t *))getCode(); + } + +private: + int itype_sz; + int otype_sz; + int stype_sz; + + Reg64 reg_ptr_in = rsi; + Reg64 reg_ptr_out = rdx; + Reg64 reg_ptr_scale = abi_not_param1; + + Reg64 reg_off_in = r8; + Reg64 reg_off_out = r9; + Reg64 reg_off_scale = r10; + + Reg64 reg_tmp = rax; + + Xmm xmm_scale = xmm15; + Xmm xmm_zero = xmm14; + Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero + Xmm xmm_tmp = xmm12; +}; + +status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb, + int ndims_ker_max) { + desc.prb = prb; + desc.prb.ioff = desc.prb.ooff = 0; + + if (ndims_ker_max > prb.ndims) + return status::invalid_arguments; + + auto ndims_ker_max_f = [&]() { + size_t cur_size = 1; + for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) + if (cur_size >= ker_prb_size_min) return d; + return prb.ndims; + }; + + if (ndims_ker_max <= 0) + ndims_ker_max = ndims_ker_max_f(); + + /* traverse through kernel implementations */ + /* TODO: find a better way to do that... */ + desc.id = 0; + for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { + desc.prb.ndims = ndims_ker; + if (jit_uni_reorder_kernel_f32::applicable(desc.prb)) + return status::success; + } + + return status::unimplemented; +} + +kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { + switch (desc.id) { + case 0: return new jit_uni_reorder_kernel_f32(desc); + default: assert(!"unknown kernel id"); return nullptr; + } + + return nullptr; +} + +} + +static void prb_block_for_cache(tr::prb_t &prb) { + if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) { + /** an attempt to use caches more efficient and + * address the 4K-aliasing issue */ + /* TODO: improve the logic around here */ + int j = 1; + for (; j < prb.ndims && prb.nodes[j].is != 1; ++j); + if (j == prb.ndims) return; + + /* it makes sense to re-prioritize sequential read over + * sequential write if the former would not trash the + * cache, i.e. is == 1 and os % 2^smth != 0. Smth is + * set to 2 at the moment */ + const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1; + if (j == move_to) return; + + if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0) + prb_node_split(prb, j, 16); + + prb_node_move(prb, j, move_to); + DEBUG({ printf("cache: "); prb_dump(prb); }); + } +} + +/** finds the maximum number of dimension the kernel should process and + * optionally splits one of the dimension to achieve better balance between + * parallel driver and the kernel. */ +static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) { + size_t sz_total = 1; + for (int d = 0; d < prb.ndims; ++d) + sz_total *= prb.nodes[d].n; + + /* sz_drv_min is the minimal size for the parallel + * driver required for good parallelization */ + const size_t sz_drv_min = nstl::min( + 16 * mkldnn_get_max_threads(), + utils::div_up(sz_total, 1024)); + + /* kdims -- # of dimensions processed by a kernel + * sz_ker_cur -- product of the dimension processed by a kernel + * sz_drv_cur -- product of the dimension processed by a driver */ + + int kdims = prb.ndims; + size_t sz_drv_cur = 1; + for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims) + sz_drv_cur *= prb.nodes[kdims - 1].n; + + size_t sz_ker_cur = 1; + for (int d = 0; d < kdims; ++d) + sz_ker_cur *= prb.nodes[d].n; + + /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min. + * + * It might happen that for chosen kdims the sz_ker_cur is too small + * (less than tr::ker_prb_size_min). In that case try to split the + * innermost driver dimension into two, to increase sz_ker_cur. */ + bool want_borrow_ker_from_drv = true + && kdims < prb.ndims + && sz_ker_cur < tr::ker_prb_size_min + && sz_drv_cur > sz_drv_min; + if (want_borrow_ker_from_drv) { + /* sz_want_borrow is the minimal sz, so that: + * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min + * o) current innermost driver dimension is divisible by + * sz_want_borrow (so that we can evenly split that + * dimension into two) + * + * In the worst case the minimal sz_want_borrow is equal + * to the innermost driver dimension itself. In that case + * we will sacrifice it in favor of kernel (is it fine?). */ + size_t sz_want_borrow + = utils::div_up(tr::ker_prb_size_min, sz_ker_cur); + for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow); + if (sz_want_borrow != prb.nodes[kdims].n) + prb_node_split(prb, kdims, sz_want_borrow); + kdims += 1; + } + + /* On the other hand it might happen that for chosen kdims + * the sz_drv_cur is too small (less than sz_drv_min). In that case + * try to split the outermost kernel dimension into two, to increase + * sz_drv_cur. */ + bool want_borrow_drv_from_ker = true + && sz_ker_cur > tr::ker_prb_size_min + && sz_drv_cur < sz_drv_min; + if (want_borrow_drv_from_ker) { + size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); + for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow); + if (sz_want_borrow != prb.nodes[kdims - 1].n) + prb_node_split(prb, kdims - 1, + prb.nodes[kdims - 1].n / sz_want_borrow); + } + + ndims_ker_max = kdims; + + if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { + DEBUG({ printf("split: "); prb_dump(prb); + printf("ndims_ker_max = %d\n", ndims_ker_max); }); + } +} + +struct jit_uni_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + auto prb = tr::prb_t(); + + status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); + if (prb_init_status != status::success) return prb_init_status; + + DEBUG({ printf("init : "); prb_dump(prb); }); + prb_normalize(prb); + DEBUG({ printf("norm : "); prb_dump(prb); }); + prb_simplify(prb); + DEBUG({ printf("smpl : "); prb_dump(prb); }); + + prb_block_for_cache(prb); + + int ndims_ker_max; + prb_thread_kernel_balance(prb, ndims_ker_max); + + tr::kernel_t::desc_t ker_desc; + status_t ker_init_status + = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); + if (ker_init_status != status::success) return ker_init_status; + + const int ndims_driver = prb.ndims - ker_desc.prb.ndims; + if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) + return status::unimplemented; + + DEBUG({ printf("ker : "); prb_dump(ker_desc.prb); }); + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + _pd->prb_ = prb; + _pd->ker_desc_ = ker_desc; + return safe_ptr_assign(*reorder_pd, _pd); + } + + tr::prb_t prb_; + tr::kernel_t::desc_t ker_desc_; + }; + + jit_uni_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { + kernel_ = tr::kernel_t::create(pd()->ker_desc_); + assert(kernel_); + } + ~jit_uni_reorder_t() { delete kernel_; } + + void omp_driver_0d(int off, const char *in, char *out, + const float *scale) const { + tr::call_param_t c{in, out, scale}; + (*kernel_)(&c); + } + + void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); + c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is) + * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, + (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) + * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, + (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is + + d3 * ns[3].is) * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os + + d3 * ns[3].os) * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss + + d3 * ns[3].ss; + (*kernel_)(&c); + }); + } + + void omp_driver(const char *in, char *out, const float *scale) const { + in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); + out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); + + DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); }); + DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); }); + + int ndims = pd()->prb_.ndims; + int ndims_ker = pd()->ker_desc_.prb.ndims; + assert(ndims - ndims_ker <= ndims_driver_max); + + if (ndims - ndims_ker == 0) { + omp_driver_0d(ndims_ker, in, out, scale); + } else { + parallel(0, [&](const int ithr, const int nthr) { + switch (ndims - ndims_ker) { + case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break; + case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break; + case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break; + case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break; + default: assert(!"unimplemented"); + } + }); + } + } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM); + auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + + omp_driver(in, out, pd()->attr()->output_scales_.scales_); + + return status::success; + } + + enum { ndims_driver_max = 4 }; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + tr::kernel_t *kernel_; +}; + +status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr, + src_engine, src_md, dst_engine, dst_md); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp new file mode 100644 index 0000000000..0746ea61d3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef _JIT_UNI_REORDER_HPP +#define _JIT_UNI_REORDER_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +constexpr int max_ndims = MKLDNN_MAX_NDIMS; + +struct node_t { + size_t n; + ptrdiff_t is; // input stride + ptrdiff_t os; // output stride + ptrdiff_t ss; // scale stride +}; + +enum class scale_type_t { NONE, COMMON, MANY }; + +struct prb_t { + data_type_t itype; + data_type_t otype; + int ndims; + node_t nodes[max_ndims]; + ptrdiff_t ioff; + ptrdiff_t ooff; + scale_type_t scale_type; + float beta; +}; + +status_t prb_init(prb_t &prb, const memory_desc_t &imd, + const memory_desc_t &omd, const primitive_attr_t *attr); + +/** sorts the problem nodes so that output strides come in ascending order */ +void prb_normalize(prb_t &p); + +/** folds nodes together if possible */ +void prb_simplify(prb_t &p); + +/** splits the node dim into two of sizes n1 and n / n1 + * @warning n must be multiple of n1 */ +void prb_node_split(prb_t &p, int dim, size_t n1); + +/** swaps d0 and d1 nodes */ +void prb_node_swap(prb_t &p, int d0, int d1); + +/** moves node d0 to the d1 position. + * nodes (d0, d1] are shifted to the left if d0 < d1 or + * to the right if d0 > d1 */ +void prb_node_move(prb_t &p, int d0, int d1); + +/** dumps the problem to stdout */ +void prb_dump(const prb_t &p); + +struct call_param_t { + const void *in; + void *out; + const float *scale; +}; + +struct kernel_t { + struct desc_t { + int id; + prb_t prb; + }; + + kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {} + void operator()(const call_param_t *c) const { assert(ker_); ker_(c); } + virtual ~kernel_t() {} + + /** inits kernel descriptor: + * desc -- kernel descriptor (output) + * prb -- transposition problem (input) + * ndims_ker_max -- limit the maximum number of dimensions kernel + * will process (optional, 0 -- no limitation) */ + static status_t desc_init(desc_t &desc, const prb_t &prb, + int ndims_ker_max = 0); + + /** creates kernel for the problem described in desc */ + static kernel_t *create(const desc_t &desc); + +protected: + const desc_t desc_; + const prb_t &prb_ = desc_.prb; + void (*ker_)(const call_param_t *); +}; + +/* TODO: add trans_t class */ + +} + +/* for cpu reorder list */ +status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp new file mode 100644 index 0000000000..69b7a33604 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp @@ -0,0 +1,313 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "mkldnn_debug.h" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_uni_reorder.hpp" + +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::status; + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +/** ad-hoc structure to describe blocked memory layout */ +struct layout_desc_t { + data_type_t dt; + int ndims; + dims_t id; + dims_t dims; + strides_t strides; +}; + +status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, + layout_desc_t &ld) { + const auto md = memory_desc_wrapper(md_); + + bool ok = true + && md.is_blocking_desc() + && md.extra().flags == 0; + if (!ok) return invalid_arguments; + + const auto &bd = md.blocking_desc(); + + ld.ndims = 0; + ld.dt = md.data_type(); + + auto P = [&ld](int id, int dim, ptrdiff_t stride) { + assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); + ld.id[ld.ndims] = id; + ld.dims[ld.ndims] = dim; + ld.strides[ld.ndims] = stride; + ++ld.ndims; + }; + + dims_t blocks; + md.compute_blocks(blocks); + + for (int d = 0; d < md.ndims(); ++d) { + const int ld_ndims_start = ld.ndims; + if (blocks[d] != 1) { + stride_t stride = 1; + for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { + if (bd.inner_idxs[iblk] == d) + P(d, bd.inner_blks[iblk], stride); + stride *= bd.inner_blks[iblk]; + } + } + P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); + + // TODO: NOW: revisit, do we need a reverse? + // TODO: NOW: consider using strides instead of block sizes in md + // reverse the order of dims + for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) { + const int idx0 = ld_ndims_start + ld_d; + const int idx1 = ld.ndims - 1 - ld_d; + nstl::swap(ld.dims[idx0], ld.dims[idx1]); + nstl::swap(ld.strides[idx0], ld.strides[idx1]); + } + } + + return success; +} + +status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + const primitive_attr_t *attr) { + auto im_d = memory_desc_wrapper(imd); + auto om_d = memory_desc_wrapper(omd); + + bool ok = true + && im_d.is_blocking_desc() + && om_d.is_blocking_desc() + && !im_d.has_zero_dim() + && !om_d.has_zero_dim(); + if (!ok) + return unimplemented; + + dims_t iblocks, oblocks; + im_d.compute_blocks(iblocks); + om_d.compute_blocks(oblocks); + + /* padding_dim consistency check */ + for (int d = 0; d < im_d.ndims(); ++d) { + const auto pdim = im_d.padded_dims()[d]; + bool ok = true + && pdim == om_d.padded_dims()[d] + && pdim % iblocks[d] == 0 + && pdim % oblocks[d] == 0; + if (!ok) return unimplemented; + } + + layout_desc_t ild, old; + status_t status = cvt_mem_desc_to_layout_desc(imd, ild); + if (status != success) return status; + status = cvt_mem_desc_to_layout_desc(omd, old); + if (status != success) return status; + + p.itype = ild.dt; + p.otype = old.dt; + + p.scale_type = attr->output_scales_.has_default_values() + ? scale_type_t::NONE + : (attr->output_scales_.mask_ == 0 + ? scale_type_t::COMMON + : scale_type_t::MANY); + + ptrdiff_t ss[max_ndims] = {0}; + if (p.scale_type == scale_type_t::MANY) { + ptrdiff_t last_ss = 1; + for (int d = old.ndims - 1; d >=0; --d) { + assert((d == 0 || old.id[d - 1] <= old.id[d]) + && "logical dimensions should be in ascending order"); + if (attr->output_scales_.mask_ & (1 << old.id[d])) { + ss[d] = last_ss; + last_ss *= old.dims[d]; + } + } + } + + int ndims = 0; + + int i_pos = 0; /* state for input -- current dimension */ + int o_pos = 0; /* state for output -- current dimension */ + + while (i_pos < ild.ndims && o_pos < old.ndims) { + assert(ild.id[i_pos] == old.id[o_pos]); + if (ild.id[i_pos] != old.id[o_pos]) + return runtime_error; + + assert(ndims < max_ndims); + if (ndims == max_ndims) + return runtime_error; + + if (ild.dims[i_pos] == old.dims[o_pos]) { + p.nodes[ndims].n = ild.dims[i_pos]; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; + ++ndims; + ++i_pos; + ++o_pos; + } else if (ild.dims[i_pos] < old.dims[o_pos]) { + assert(old.dims[o_pos] % ild.dims[i_pos] == 0); + int factor = old.dims[o_pos] / ild.dims[i_pos]; + p.nodes[ndims].n = ild.dims[i_pos]; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos] * factor; + p.nodes[ndims].ss = ss[o_pos] * factor; + ++ndims; + ++i_pos; + old.dims[o_pos] = factor; + } else if (ild.dims[i_pos] > old.dims[o_pos]) { + assert(ild.dims[i_pos] % old.dims[o_pos] == 0); + int factor = ild.dims[i_pos] / old.dims[o_pos]; + p.nodes[ndims].n = old.dims[o_pos]; + p.nodes[ndims].is = ild.strides[i_pos] * factor; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; + ++ndims; + ++o_pos; + ild.dims[i_pos] = factor; + } + } + p.ndims = ndims; + + dims_t zero_pos = {0}; + p.ioff = memory_desc_wrapper(imd).off_v(zero_pos); + p.ooff = memory_desc_wrapper(omd).off_v(zero_pos); + + const int sum_idx = attr->post_ops_.find(primitive_kind::sum); + p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; + + return success; +} + +void prb_normalize(prb_t &p) { + for (int d = 0; d < p.ndims; ++d) { + int min_pos = d; + for (int j = d + 1; j < p.ndims; ++j) { + bool new_min = false + || p.nodes[j].os < p.nodes[min_pos].os + || (true + && p.nodes[j].os == p.nodes[min_pos].os + && p.nodes[j].n < p.nodes[min_pos].n); + if (new_min) min_pos = j; + } + if (min_pos != d) + nstl::swap(p.nodes[d], p.nodes[min_pos]); + } +} + +void prb_simplify(prb_t &p) { +#if defined(__GNUC__) && __GNUC__ >= 4 +/* GCC produces bogus array subscript is above array bounds warning for + * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" +#endif + for (int d = 0; d < p.ndims - 1; ++d) { + auto &this_node = p.nodes[d + 0]; + auto &next_node = p.nodes[d + 1]; + const bool fold = false + || next_node.n == (size_t)1 // trivial case, just drop next node + || (true // or real folding if possible + && next_node.is == (ptrdiff_t)this_node.n * this_node.is + && next_node.os == (ptrdiff_t)this_node.n * this_node.os + && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss); + if (fold) { + this_node.n *= next_node.n; + for (int j = d + 2; j < p.ndims; ++j) + p.nodes[j - 1] = p.nodes[j]; + --p.ndims; + --d; // make another try + } + } +#if defined(__GNUC__) && __GNUC__ >= 4 +#pragma GCC diagnostic pop +#endif +} + +void prb_node_split(prb_t &p, int dim, size_t n1) { + assert(dim < p.ndims); + assert(p.ndims < max_ndims); + assert(p.nodes[dim].n % n1 == 0); + + p.ndims += 1; + + for (int d = p.ndims; d > dim + 1; --d) + p.nodes[d] = p.nodes[d - 1]; + + p.nodes[dim + 1].n = p.nodes[dim].n / n1; + p.nodes[dim + 1].is = p.nodes[dim].is * n1; + p.nodes[dim + 1].os = p.nodes[dim].os * n1; + p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; + + p.nodes[dim].n = n1; +} + +void prb_node_swap(prb_t &p, int d0, int d1) { + assert(d0 < p.ndims); + assert(d1 < p.ndims); + assert(p.ndims < max_ndims); + + if (d0 == d1) return; + + nstl::swap(p.nodes[d0], p.nodes[d1]); +} + +void prb_node_move(prb_t &p, int d0, int d1) { + assert(d0 < p.ndims); + assert(d1 < p.ndims); + assert(p.ndims < max_ndims); + + if (d0 == d1) return; + + node_t node = p.nodes[d0]; + + if (d0 < d1) + for (int d = d0; d < d1; ++d) + p.nodes[d] = p.nodes[d + 1]; + else + for (int d = d0; d > d1; --d) + p.nodes[d] = p.nodes[d - 1]; + + p.nodes[d1] = node; +} + +void prb_dump(const prb_t &p) { + printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype), + mkldnn_dt2str(p.otype), p.ndims); + for (int d = 0; d < p.ndims; ++d) + printf("[%zu:%td:%td:%td]", + p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss); + printf(" off:%zu:%zu\n", p.ioff, p.ooff); +} + +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp new file mode 100644 index 0000000000..08747aa89c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "utils.hpp" + +#ifndef MKLDNN_ENABLE_JIT_PROFILING +#define MKLDNN_ENABLE_JIT_PROFILING 1 +#endif + +#ifndef MKLDNN_ENABLE_JIT_DUMP +#define MKLDNN_ENABLE_JIT_DUMP 1 +#endif + +#if MKLDNN_ENABLE_JIT_PROFILING +#include "jitprofiling/jitprofiling.h" +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace jit_utils { + +// WARNING: These functions are not thread safe and must be protected by a +// mutex + +void dump_jit_code(const void *code, size_t code_size, const char *code_name) +{ +#if MKLDNN_ENABLE_JIT_DUMP + if (code && jit_dump_enabled()) { + static int counter = 0; +#define MAX_FNAME_LEN 256 + char fname[MAX_FNAME_LEN + 1]; + // TODO (Roma): support prefix for code / linux perf dumps + snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", code_name, + counter); + counter++; + + FILE *fp = fopen(fname, "w+"); + // Failure to dump code is not fatal + if (fp) { + size_t unused = fwrite(code, code_size, 1, fp); + UNUSED(unused); + fclose(fp); + } + } +#undef MAX_FNAME_LEN +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); +#endif +} + +void register_jit_code_vtune(const void *code, size_t code_size, + const char *code_name, const char *source_file_name) +{ +#if MKLDNN_ENABLE_JIT_PROFILING + if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) { + auto jmethod = iJIT_Method_Load(); + jmethod.method_id = iJIT_GetNewMethodID(); // XXX: not thread-safe + jmethod.method_name = (char *)code_name; // XXX: dropping const + jmethod.class_file_name = NULL; + jmethod.source_file_name = (char *)source_file_name; // XXX: dropping const + jmethod.method_load_address = (void *)code; + jmethod.method_size = (unsigned int)code_size; + + iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, + (void*)&jmethod); + } +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); + UNUSED(source_file_name); +#endif +} + +void register_jit_code(const void *code, size_t code_size, + const char *code_name, const char *source_file_name) +{ + // The #ifdef guards are required to avoid generating a function that only + // consists of lock and unlock code +#if MKLDNN_ENABLE_JIT_PROFILING || MKLDNN_ENABLE_JIT_DUMP + static std::mutex m; + std::lock_guard guard(m); + + dump_jit_code(code, code_size, code_name); + register_jit_code_vtune(code, code_size, code_name, source_file_name); +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); + UNUSED(source_file_name); +#endif +} + +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp new file mode 100644 index 0000000000..2f52dba4ac --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SUPPORT_HPP +#define JIT_SUPPORT_HPP + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace jit_utils { + +void register_jit_code(const void *code, size_t code_size, + const char *code_name, const char *source_file_name); + +} +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD new file mode 100644 index 0000000000..4fd21cea57 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD @@ -0,0 +1,27 @@ +Copyright (c) 2011, Intel Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md new file mode 100644 index 0000000000..fc67c4f134 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md @@ -0,0 +1 @@ +This code is from [Intel SEAPI library](https://github.com/intel/IntelSEAPI) diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h new file mode 100644 index 0000000000..edbf4a15f0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h @@ -0,0 +1,595 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +#ifndef _ITTNOTIFY_CONFIG_H_ +#define _ITTNOTIFY_CONFIG_H_ + +/** @cond exclude_from_documentation */ +#ifndef ITT_OS_WIN +# define ITT_OS_WIN 1 +#endif /* ITT_OS_WIN */ + +#ifndef ITT_OS_LINUX +# define ITT_OS_LINUX 2 +#endif /* ITT_OS_LINUX */ + +#ifndef ITT_OS_MAC +# define ITT_OS_MAC 3 +#endif /* ITT_OS_MAC */ + +#ifndef ITT_OS_FREEBSD +# define ITT_OS_FREEBSD 4 +#endif /* ITT_OS_FREEBSD */ + +#ifndef ITT_OS +# if defined WIN32 || defined _WIN32 +# define ITT_OS ITT_OS_WIN +# elif defined( __APPLE__ ) && defined( __MACH__ ) +# define ITT_OS ITT_OS_MAC +# elif defined( __FreeBSD__ ) +# define ITT_OS ITT_OS_FREEBSD +# else +# define ITT_OS ITT_OS_LINUX +# endif +#endif /* ITT_OS */ + +#ifndef ITT_PLATFORM_WIN +# define ITT_PLATFORM_WIN 1 +#endif /* ITT_PLATFORM_WIN */ + +#ifndef ITT_PLATFORM_POSIX +# define ITT_PLATFORM_POSIX 2 +#endif /* ITT_PLATFORM_POSIX */ + +#ifndef ITT_PLATFORM_MAC +# define ITT_PLATFORM_MAC 3 +#endif /* ITT_PLATFORM_MAC */ + +#ifndef ITT_PLATFORM_FREEBSD +# define ITT_PLATFORM_FREEBSD 4 +#endif /* ITT_PLATFORM_FREEBSD */ + +#ifndef ITT_PLATFORM +# if ITT_OS==ITT_OS_WIN +# define ITT_PLATFORM ITT_PLATFORM_WIN +# elif ITT_OS==ITT_OS_MAC +# define ITT_PLATFORM ITT_PLATFORM_MAC +# elif ITT_OS==ITT_OS_FREEBSD +# define ITT_PLATFORM ITT_PLATFORM_FREEBSD +# else +# define ITT_PLATFORM ITT_PLATFORM_POSIX +# endif +#endif /* ITT_PLATFORM */ + +#if defined(_UNICODE) && !defined(UNICODE) +#define UNICODE +#endif + +#include +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include +#if defined(UNICODE) || defined(_UNICODE) +#include +#endif /* UNICODE || _UNICODE */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef ITTAPI_CDECL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define ITTAPI_CDECL __cdecl +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define ITTAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define ITTAPI_CDECL /* actual only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* ITTAPI_CDECL */ + +#ifndef STDCALL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define STDCALL __stdcall +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define STDCALL __attribute__ ((stdcall)) +# else /* _M_IX86 || __i386__ */ +# define STDCALL /* supported only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* STDCALL */ + +#define ITTAPI ITTAPI_CDECL +#define LIBITTAPI ITTAPI_CDECL + +/* TODO: Temporary for compatibility! */ +#define ITTAPI_CALL ITTAPI_CDECL +#define LIBITTAPI_CALL ITTAPI_CDECL + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +/* use __forceinline (VC++ specific) */ +#define ITT_INLINE __forceinline +#define ITT_INLINE_ATTRIBUTE /* nothing */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/* + * Generally, functions are not inlined unless optimization is specified. + * For functions declared inline, this attribute inlines the function even + * if no optimization level was specified. + */ +#ifdef __STRICT_ANSI__ +#define ITT_INLINE static +#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) +#else /* __STRICT_ANSI__ */ +#define ITT_INLINE static inline +#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) +#endif /* __STRICT_ANSI__ */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/** @endcond */ + +#ifndef ITT_ARCH_IA32 +# define ITT_ARCH_IA32 1 +#endif /* ITT_ARCH_IA32 */ + +#ifndef ITT_ARCH_IA32E +# define ITT_ARCH_IA32E 2 +#endif /* ITT_ARCH_IA32E */ + +#ifndef ITT_ARCH_ARM +# define ITT_ARCH_ARM 4 +#endif /* ITT_ARCH_ARM */ + +#ifndef ITT_ARCH_PPC64 +# define ITT_ARCH_PPC64 5 +#endif /* ITT_ARCH_PPC64 */ + +#ifndef ITT_ARCH +# if defined _M_IX86 || defined __i386__ +# define ITT_ARCH ITT_ARCH_IA32 +# elif defined _M_X64 || defined _M_AMD64 || defined __x86_64__ +# define ITT_ARCH ITT_ARCH_IA32E +# elif defined _M_IA64 || defined __ia64__ +# define ITT_ARCH ITT_ARCH_IA64 +# elif defined _M_ARM || defined __arm__ +# define ITT_ARCH ITT_ARCH_ARM +# elif defined __powerpc64__ +# define ITT_ARCH ITT_ARCH_PPC64 +# endif +#endif + +#ifdef __cplusplus +# define ITT_EXTERN_C extern "C" +# define ITT_EXTERN_C_BEGIN extern "C" { +# define ITT_EXTERN_C_END } +#else +# define ITT_EXTERN_C /* nothing */ +# define ITT_EXTERN_C_BEGIN /* nothing */ +# define ITT_EXTERN_C_END /* nothing */ +#endif /* __cplusplus */ + +#define ITT_TO_STR_AUX(x) #x +#define ITT_TO_STR(x) ITT_TO_STR_AUX(x) + +#define __ITT_BUILD_ASSERT(expr, suffix) do { \ + static char __itt_build_check_##suffix[(expr) ? 1 : -1]; \ + __itt_build_check_##suffix[0] = 0; \ +} while(0) +#define _ITT_BUILD_ASSERT(expr, suffix) __ITT_BUILD_ASSERT((expr), suffix) +#define ITT_BUILD_ASSERT(expr) _ITT_BUILD_ASSERT((expr), __LINE__) + +#define ITT_MAGIC { 0xED, 0xAB, 0xAB, 0xEC, 0x0D, 0xEE, 0xDA, 0x30 } + +/* Replace with snapshot date YYYYMMDD for promotion build. */ +#define API_VERSION_BUILD 20151119 + +#ifndef API_VERSION_NUM +#define API_VERSION_NUM 0.0.0 +#endif /* API_VERSION_NUM */ + +#define API_VERSION "ITT-API-Version " ITT_TO_STR(API_VERSION_NUM) \ + " (" ITT_TO_STR(API_VERSION_BUILD) ")" + +/* OS communication functions */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +typedef HMODULE lib_t; +typedef DWORD TIDT; +typedef CRITICAL_SECTION mutex_t; +#define MUTEX_INITIALIZER { 0 } +#define strong_alias(name, aliasname) /* empty for Windows */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include +#if defined(UNICODE) || defined(_UNICODE) +#include +#endif /* UNICODE */ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE 1 /* need for PTHREAD_MUTEX_RECURSIVE */ +#endif /* _GNU_SOURCE */ +#ifndef __USE_UNIX98 +#define __USE_UNIX98 1 /* need for PTHREAD_MUTEX_RECURSIVE, on SLES11.1 with gcc 4.3.4 wherein pthread.h missing dependency on __USE_XOPEN2K8 */ +#endif /*__USE_UNIX98*/ +#include +typedef void* lib_t; +typedef pthread_t TIDT; +typedef pthread_mutex_t mutex_t; +#define MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER +#define _strong_alias(name, aliasname) \ + extern __typeof (name) aliasname __attribute__ ((alias (#name))); +#define strong_alias(name, aliasname) _strong_alias(name, aliasname) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_get_proc(lib, name) GetProcAddress(lib, name) +#define __itt_mutex_init(mutex) InitializeCriticalSection(mutex) +#define __itt_mutex_lock(mutex) EnterCriticalSection(mutex) +#define __itt_mutex_unlock(mutex) LeaveCriticalSection(mutex) +#define __itt_load_lib(name) LoadLibraryA(name) +#define __itt_unload_lib(handle) FreeLibrary(handle) +#define __itt_system_error() (int)GetLastError() +#define __itt_fstrcmp(s1, s2) lstrcmpA(s1, s2) +#define __itt_fstrnlen(s, l) strnlen_s(s, l) +#define __itt_fstrcpyn(s1, b, s2, l) strncpy_s(s1, b, s2, l) +#define __itt_fstrdup(s) _strdup(s) +#define __itt_thread_id() GetCurrentThreadId() +#define __itt_thread_yield() SwitchToThread() +#ifndef ITT_SIMPLE_INIT +ITT_INLINE long +__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) +{ + return InterlockedIncrement(ptr); +} +#endif /* ITT_SIMPLE_INIT */ + +#define DL_SYMBOLS (1) +#define PTHREAD_SYMBOLS (1) + +#else /* ITT_PLATFORM!=ITT_PLATFORM_WIN */ +#define __itt_get_proc(lib, name) dlsym(lib, name) +#define __itt_mutex_init(mutex) {\ + pthread_mutexattr_t mutex_attr; \ + int error_code = pthread_mutexattr_init(&mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_init", \ + error_code); \ + error_code = pthread_mutexattr_settype(&mutex_attr, \ + PTHREAD_MUTEX_RECURSIVE); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_settype", \ + error_code); \ + error_code = pthread_mutex_init(mutex, &mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutex_init", \ + error_code); \ + error_code = pthread_mutexattr_destroy(&mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_destroy", \ + error_code); \ +} +#define __itt_mutex_lock(mutex) pthread_mutex_lock(mutex) +#define __itt_mutex_unlock(mutex) pthread_mutex_unlock(mutex) +#define __itt_load_lib(name) dlopen(name, RTLD_LAZY) +#define __itt_unload_lib(handle) dlclose(handle) +#define __itt_system_error() errno +#define __itt_fstrcmp(s1, s2) strcmp(s1, s2) + +/* makes customer code define safe APIs for SDL_STRNLEN_S and SDL_STRNCPY_S */ +#ifdef SDL_STRNLEN_S +#define __itt_fstrnlen(s, l) SDL_STRNLEN_S(s, l) +#else +#define __itt_fstrnlen(s, l) strlen(s) +#endif /* SDL_STRNLEN_S */ +#ifdef SDL_STRNCPY_S +#define __itt_fstrcpyn(s1, b, s2, l) SDL_STRNCPY_S(s1, b, s2, l) +#else +#define __itt_fstrcpyn(s1, b, s2, l) strncpy(s1, s2, l) +#endif /* SDL_STRNCPY_S */ + +#define __itt_fstrdup(s) strdup(s) +#define __itt_thread_id() pthread_self() +#define __itt_thread_yield() sched_yield() +#if ITT_ARCH==ITT_ARCH_IA64 +#ifdef __INTEL_COMPILER +#define __TBB_machine_fetchadd4(addr, val) __fetchadd4_acq((void *)addr, val) +#else /* __INTEL_COMPILER */ +/* TODO: Add Support for not Intel compilers for IA-64 architecture */ +#endif /* __INTEL_COMPILER */ +#elif ITT_ARCH==ITT_ARCH_IA32 || ITT_ARCH==ITT_ARCH_IA32E /* ITT_ARCH!=ITT_ARCH_IA64 */ +ITT_INLINE long +__TBB_machine_fetchadd4(volatile void* ptr, long addend) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __TBB_machine_fetchadd4(volatile void* ptr, long addend) +{ + long result; + __asm__ __volatile__("lock\nxadd %0,%1" + : "=r"(result),"=m"(*(int*)ptr) + : "0"(addend), "m"(*(int*)ptr) + : "memory"); + return result; +} +#elif ITT_ARCH==ITT_ARCH_ARM || ITT_ARCH==ITT_ARCH_PPC64 +#define __TBB_machine_fetchadd4(addr, val) __sync_fetch_and_add(addr, val) +#endif /* ITT_ARCH==ITT_ARCH_IA64 */ +#ifndef ITT_SIMPLE_INIT +ITT_INLINE long +__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) +{ + return __TBB_machine_fetchadd4(ptr, 1) + 1L; +} +#endif /* ITT_SIMPLE_INIT */ + +void* dlopen(const char*, int) __attribute__((weak)); +void* dlsym(void*, const char*) __attribute__((weak)); +int dlclose(void*) __attribute__((weak)); +#define DL_SYMBOLS (dlopen && dlsym && dlclose) + +int pthread_mutex_init(pthread_mutex_t*, const pthread_mutexattr_t*) __attribute__((weak)); +int pthread_mutex_lock(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutex_unlock(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutex_destroy(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutexattr_init(pthread_mutexattr_t*) __attribute__((weak)); +int pthread_mutexattr_settype(pthread_mutexattr_t*, int) __attribute__((weak)); +int pthread_mutexattr_destroy(pthread_mutexattr_t*) __attribute__((weak)); +pthread_t pthread_self(void) __attribute__((weak)); +#define PTHREAD_SYMBOLS (pthread_mutex_init && pthread_mutex_lock && pthread_mutex_unlock && pthread_mutex_destroy && pthread_mutexattr_init && pthread_mutexattr_settype && pthread_mutexattr_destroy && pthread_self) + +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +typedef enum { + __itt_collection_normal = 0, + __itt_collection_paused = 1 +} __itt_collection_state; + +typedef enum { + __itt_thread_normal = 0, + __itt_thread_ignored = 1 +} __itt_thread_state; + +#pragma pack(push, 8) + +typedef struct ___itt_thread_info +{ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + TIDT tid; + __itt_thread_state state; /*!< Thread state (paused or normal) */ + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct ___itt_thread_info* next; +} __itt_thread_info; + +#include "ittnotify_types.h" /* For __itt_group_id definition */ + +typedef struct ___itt_api_info_20101001 +{ + const char* name; + void** func_ptr; + void* init_func; + __itt_group_id group; +} __itt_api_info_20101001; + +typedef struct ___itt_api_info +{ + const char* name; + void** func_ptr; + void* init_func; + void* null_func; + __itt_group_id group; +} __itt_api_info; + +typedef struct __itt_counter_info +{ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + const char* domainA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* domainW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* domainW; +#endif /* UNICODE || _UNICODE */ + int type; + long index; + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct __itt_counter_info* next; +} __itt_counter_info_t; + +struct ___itt_domain; +struct ___itt_string_handle; + +typedef struct ___itt_global +{ + unsigned char magic[8]; + unsigned long version_major; + unsigned long version_minor; + unsigned long version_build; + volatile long api_initialized; + volatile long mutex_initialized; + volatile long atomic_counter; + mutex_t mutex; + lib_t lib; + void* error_handler; + const char** dll_path_ptr; + __itt_api_info* api_list_ptr; + struct ___itt_global* next; + /* Joinable structures below */ + __itt_thread_info* thread_list; + struct ___itt_domain* domain_list; + struct ___itt_string_handle* string_list; + __itt_collection_state state; + __itt_counter_info_t* counter_list; +} __itt_global; + +#pragma pack(pop) + +#define NEW_THREAD_INFO_W(gptr,h,h_tail,t,s,n) { \ + h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ + if (h != NULL) { \ + h->tid = t; \ + h->nameA = NULL; \ + h->nameW = n ? _wcsdup(n) : NULL; \ + h->state = s; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->thread_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_THREAD_INFO_A(gptr,h,h_tail,t,s,n) { \ + h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ + if (h != NULL) { \ + h->tid = t; \ + h->nameA = n ? __itt_fstrdup(n) : NULL; \ + h->nameW = NULL; \ + h->state = s; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->thread_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_DOMAIN_W(gptr,h,h_tail,name) { \ + h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ + if (h != NULL) { \ + h->flags = 1; /* domain is enabled by default */ \ + h->nameA = NULL; \ + h->nameW = name ? _wcsdup(name) : NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->domain_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_DOMAIN_A(gptr,h,h_tail,name) { \ + h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ + if (h != NULL) { \ + h->flags = 1; /* domain is enabled by default */ \ + h->nameA = name ? __itt_fstrdup(name) : NULL; \ + h->nameW = NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->domain_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_STRING_HANDLE_W(gptr,h,h_tail,name) { \ + h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ + if (h != NULL) { \ + h->strA = NULL; \ + h->strW = name ? _wcsdup(name) : NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->string_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_STRING_HANDLE_A(gptr,h,h_tail,name) { \ + h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ + if (h != NULL) { \ + h->strA = name ? __itt_fstrdup(name) : NULL; \ + h->strW = NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->string_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_COUNTER_W(gptr,h,h_tail,name,domain,type) { \ + h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ + if (h != NULL) { \ + h->nameA = NULL; \ + h->nameW = name ? _wcsdup(name) : NULL; \ + h->domainA = NULL; \ + h->domainW = name ? _wcsdup(domain) : NULL; \ + h->type = type; \ + h->index = 0; \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->counter_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_COUNTER_A(gptr,h,h_tail,name,domain,type) { \ + h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ + if (h != NULL) { \ + h->nameA = name ? __itt_fstrdup(name) : NULL; \ + h->nameW = NULL; \ + h->domainA = domain ? __itt_fstrdup(domain) : NULL; \ + h->domainW = NULL; \ + h->type = type; \ + h->index = 0; \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->counter_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#endif /* _ITTNOTIFY_CONFIG_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h new file mode 100644 index 0000000000..99fbc24054 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h @@ -0,0 +1,94 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _ITTNOTIFY_TYPES_H_ +#define _ITTNOTIFY_TYPES_H_ + +typedef enum ___itt_group_id +{ + __itt_group_none = 0, + __itt_group_legacy = 1<<0, + __itt_group_control = 1<<1, + __itt_group_thread = 1<<2, + __itt_group_mark = 1<<3, + __itt_group_sync = 1<<4, + __itt_group_fsync = 1<<5, + __itt_group_jit = 1<<6, + __itt_group_model = 1<<7, + __itt_group_splitter_min = 1<<7, + __itt_group_counter = 1<<8, + __itt_group_frame = 1<<9, + __itt_group_stitch = 1<<10, + __itt_group_heap = 1<<11, + __itt_group_splitter_max = 1<<12, + __itt_group_structure = 1<<12, + __itt_group_suppress = 1<<13, + __itt_group_arrays = 1<<14, + __itt_group_all = -1 +} __itt_group_id; + +#pragma pack(push, 8) + +typedef struct ___itt_group_list +{ + __itt_group_id id; + const char* name; +} __itt_group_list; + +#pragma pack(pop) + +#define ITT_GROUP_LIST(varname) \ + static __itt_group_list varname[] = { \ + { __itt_group_all, "all" }, \ + { __itt_group_control, "control" }, \ + { __itt_group_thread, "thread" }, \ + { __itt_group_mark, "mark" }, \ + { __itt_group_sync, "sync" }, \ + { __itt_group_fsync, "fsync" }, \ + { __itt_group_jit, "jit" }, \ + { __itt_group_model, "model" }, \ + { __itt_group_counter, "counter" }, \ + { __itt_group_frame, "frame" }, \ + { __itt_group_stitch, "stitch" }, \ + { __itt_group_heap, "heap" }, \ + { __itt_group_structure, "structure" }, \ + { __itt_group_suppress, "suppress" }, \ + { __itt_group_arrays, "arrays" }, \ + { __itt_group_none, NULL } \ + } + +#endif /* _ITTNOTIFY_TYPES_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c new file mode 100644 index 0000000000..15f4b9929b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c @@ -0,0 +1,293 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "ittnotify_config.h" + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM != ITT_PLATFORM_MAC && ITT_PLATFORM != ITT_PLATFORM_FREEBSD +#include +#endif +#include + +#include "jitprofiling.h" + +static const char rcsid[] = "\n@(#) $Revision: 471937 $\n"; + +#define DLL_ENVIRONMENT_VAR "VS_PROFILER" + +#ifndef NEW_DLL_ENVIRONMENT_VAR +#if ITT_ARCH==ITT_ARCH_IA32 +#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER32" +#else +#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER64" +#endif +#endif /* NEW_DLL_ENVIRONMENT_VAR */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define DEFAULT_DLLNAME "JitPI.dll" +HINSTANCE m_libHandle = NULL; +#elif ITT_PLATFORM==ITT_PLATFORM_MAC +#define DEFAULT_DLLNAME "libJitPI.dylib" +void* m_libHandle = NULL; +#else +#define DEFAULT_DLLNAME "libJitPI.so" +void* m_libHandle = NULL; +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/* default location of JIT profiling agent on Android */ +#define ANDROID_JIT_AGENT_PATH "/data/intel/libittnotify.so" + +/* the function pointers */ +typedef unsigned int(JITAPI *TPInitialize)(void); +static TPInitialize FUNC_Initialize=NULL; + +typedef unsigned int(JITAPI *TPNotify)(unsigned int, void*); +static TPNotify FUNC_NotifyEvent=NULL; + +static iJIT_IsProfilingActiveFlags executionMode = iJIT_NOTHING_RUNNING; + +/* end collector dll part. */ + +/* loadiJIT_Funcs() : this function is called just in the beginning + * and is responsible to load the functions from BistroJavaCollector.dll + * result: + * on success: the functions loads, iJIT_DLL_is_missing=0, return value = 1 + * on failure: the functions are NULL, iJIT_DLL_is_missing=1, return value = 0 + */ +static int loadiJIT_Funcs(void); + +/* global representing whether the collector can't be loaded */ +static int iJIT_DLL_is_missing = 0; + +ITT_EXTERN_C int JITAPI +iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData) +{ + int ReturnValue = 0; + + /* initialization part - the collector has not been loaded yet. */ + if (!FUNC_NotifyEvent) + { + if (iJIT_DLL_is_missing) + return 0; + + if (!loadiJIT_Funcs()) + return 0; + } + + if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED || + event_type == iJVM_EVENT_TYPE_METHOD_UPDATE) + { + if (((piJIT_Method_Load)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2) + { + if (((piJIT_Method_Load_V2)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3) + { + if (((piJIT_Method_Load_V3)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED) + { + if (((piJIT_Method_Inline_Load)EventSpecificData)->method_id == 0 || + ((piJIT_Method_Inline_Load)EventSpecificData)->parent_method_id == 0) + return 0; + } + + ReturnValue = (int)FUNC_NotifyEvent(event_type, EventSpecificData); + + return ReturnValue; +} + +ITT_EXTERN_C iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive() +{ + if (!iJIT_DLL_is_missing) + { + loadiJIT_Funcs(); + } + + return executionMode; +} + +/* This function loads the collector dll and the relevant functions. + * on success: all functions load, iJIT_DLL_is_missing = 0, return value = 1 + * on failure: all functions are NULL, iJIT_DLL_is_missing = 1, return value = 0 + */ +static int loadiJIT_Funcs() +{ + static int bDllWasLoaded = 0; + char *dllName = (char*)rcsid; /* !! Just to avoid unused code elimination */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN + DWORD dNameLength = 0; +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + + if(bDllWasLoaded) + { + /* dll was already loaded, no need to do it for the second time */ + return 1; + } + + /* Assumes that the DLL will not be found */ + iJIT_DLL_is_missing = 1; + FUNC_NotifyEvent = NULL; + + if (m_libHandle) + { +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FreeLibrary(m_libHandle); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + dlclose(m_libHandle); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + m_libHandle = NULL; + } + + /* Try to get the dll name from the environment */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN + dNameLength = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, NULL, 0); + if (dNameLength) + { + DWORD envret = 0; + dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); + if(dllName != NULL) + { + envret = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, + dllName, dNameLength); + if (envret) + { + /* Try to load the dll from the PATH... */ + m_libHandle = LoadLibraryExA(dllName, + NULL, LOAD_WITH_ALTERED_SEARCH_PATH); + } + free(dllName); + } + } else { + /* Try to use old VS_PROFILER variable */ + dNameLength = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, NULL, 0); + if (dNameLength) + { + DWORD envret = 0; + dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); + if(dllName != NULL) + { + envret = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, + dllName, dNameLength); + if (envret) + { + /* Try to load the dll from the PATH... */ + m_libHandle = LoadLibraryA(dllName); + } + free(dllName); + } + } + } +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + dllName = getenv(NEW_DLL_ENVIRONMENT_VAR); + if (!dllName) + dllName = getenv(DLL_ENVIRONMENT_VAR); +#if defined(__ANDROID__) || defined(ANDROID) + if (!dllName) + dllName = ANDROID_JIT_AGENT_PATH; +#endif + if (dllName) + { + /* Try to load the dll from the PATH... */ + m_libHandle = dlopen(dllName, RTLD_LAZY); + } +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + + if (!m_libHandle) + { +#if ITT_PLATFORM==ITT_PLATFORM_WIN + m_libHandle = LoadLibraryA(DEFAULT_DLLNAME); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + m_libHandle = dlopen(DEFAULT_DLLNAME, RTLD_LAZY); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + } + + /* if the dll wasn't loaded - exit. */ + if (!m_libHandle) + { + iJIT_DLL_is_missing = 1; /* don't try to initialize + * JIT agent the second time + */ + return 0; + } + +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FUNC_NotifyEvent = (TPNotify)GetProcAddress(m_libHandle, "NotifyEvent"); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + FUNC_NotifyEvent = (TPNotify)dlsym(m_libHandle, "NotifyEvent"); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + if (!FUNC_NotifyEvent) + { + FUNC_Initialize = NULL; + return 0; + } + +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FUNC_Initialize = (TPInitialize)GetProcAddress(m_libHandle, "Initialize"); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + FUNC_Initialize = (TPInitialize)dlsym(m_libHandle, "Initialize"); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + if (!FUNC_Initialize) + { + FUNC_NotifyEvent = NULL; + return 0; + } + + executionMode = (iJIT_IsProfilingActiveFlags)FUNC_Initialize(); + + bDllWasLoaded = 1; + iJIT_DLL_is_missing = 0; /* DLL is ok. */ + + return 1; +} + +ITT_EXTERN_C unsigned int JITAPI iJIT_GetNewMethodID() +{ + static unsigned int methodID = 1; + + if (methodID == 0) + return 0; /* ERROR : this is not a valid value */ + + return methodID++; +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h new file mode 100644 index 0000000000..bf0489b1a1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h @@ -0,0 +1,673 @@ +/* + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef __JITPROFILING_H__ +#define __JITPROFILING_H__ + +/** + * @brief JIT Profiling APIs + * + * The JIT Profiling API is used to report information about just-in-time + * generated code that can be used by performance tools. The user inserts + * calls in the code generator to report information before JIT-compiled + * code goes to execution. This information is collected at runtime and used + * by tools like Intel(R) VTune(TM) Amplifier to display performance metrics + * associated with JIT-compiled code. + * + * These APIs can be used to\n + * - **Profile trace-based and method-based JIT-compiled + * code**. Some examples of environments that you can profile with these APIs: + * dynamic JIT compilation of JavaScript code traces, JIT execution in OpenCL(TM) + * software technology, Java/.NET managed execution environments, and custom + * ISV JIT engines. + * @code + * #include + * + * if (iJIT_IsProfilingActive != iJIT_SAMPLING_ON) { + * return; + * } + * + * iJIT_Method_Load jmethod = {0}; + * jmethod.method_id = iJIT_GetNewMethodID(); + * jmethod.method_name = "method_name"; + * jmethod.class_file_name = "class_name"; + * jmethod.source_file_name = "source_file_name"; + * jmethod.method_load_address = code_addr; + * jmethod.method_size = code_size; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_SHUTDOWN, NULL); + * @endcode + * + * * Expected behavior: + * * If any iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and its + * memory region is treated as unloaded. VTune Amplifier displays the metrics + * collected by the method until it is overwritten. + * * If supplied line number information contains multiple source lines for + * the same assembly instruction (code location), then VTune Amplifier picks up + * the first line number. + * * Dynamically generated code can be associated with a module name. + * Use the iJIT_Method_Load_V2 structure.\n + * Clarification of some cases: + * * If you register a function with the same method ID multiple times, + * specifying different module names, then the VTune Amplifier picks up + * the module name registered first. If you want to distinguish the same + * function between different JIT engines, supply different method IDs for + * each function. Other symbolic information (for example, source file) + * can be identical. + * + * - **Analyze split functions** (multiple joint or disjoint code regions + * belonging to the same function) **including re-JIT** + * with potential overlapping of code regions in time, which is common in + * resource-limited environments. + * @code + * #include + * + * unsigned int method_id = iJIT_GetNewMethodID(); + * + * iJIT_Method_Load a = {0}; + * a.method_id = method_id; + * a.method_load_address = 0x100; + * a.method_size = 0x20; + * + * iJIT_Method_Load b = {0}; + * b.method_id = method_id; + * b.method_load_address = 0x200; + * b.method_size = 0x30; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&b); + * @endcode + * + * * Expected behaviour: + * * If a iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and + * its memory region is treated as unloaded. + * * All code regions reported with the same method ID are considered as + * belonging to the same method. Symbolic information (method name, + * source file name) will be taken from the first notification, and all + * subsequent notifications with the same method ID will be processed + * only for line number table information. So, the VTune Amplifier will map + * samples to a source line using the line number table from the current + * notification while taking the source file name from the very first one.\n + * Clarification of some cases:\n + * * If you register a second code region with a different source file + * name and the same method ID, then this information will be saved and + * will not be considered as an extension of the first code region, but + * VTune Amplifier will use the source file of the first code region and map + * performance metrics incorrectly. + * * If you register a second code region with the same source file as + * for the first region and the same method ID, then the source file will be + * discarded but VTune Amplifier will map metrics to the source file correctly. + * * If you register a second code region with a null source file and + * the same method ID, then provided line number info will be associated + * with the source file of the first code region. + * + * - **Explore inline functions** including multi-level hierarchy of + * nested inline methods which shows how performance metrics are distributed through them. + * @code + * #include + * + * // method_id parent_id + * // [-- c --] 3000 2000 + * // [---- d -----] 2001 1000 + * // [---- b ----] 2000 1000 + * // [------------ a ----------------] 1000 n/a + * + * iJIT_Method_Load a = {0}; + * a.method_id = 1000; + * + * iJIT_Method_Inline_Load b = {0}; + * b.method_id = 2000; + * b.parent_method_id = 1000; + * + * iJIT_Method_Inline_Load c = {0}; + * c.method_id = 3000; + * c.parent_method_id = 2000; + * + * iJIT_Method_Inline_Load d = {0}; + * d.method_id = 2001; + * d.parent_method_id = 1000; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&b); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&c); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&d); + * @endcode + * + * * Requirements: + * * Each inline (iJIT_Method_Inline_Load) method should be associated + * with two method IDs: one for itself; one for its immediate parent. + * * Address regions of inline methods of the same parent method cannot + * overlap each other. + * * Execution of the parent method must not be started until it and all + * its inline methods are reported. + * * Expected behaviour: + * * In case of nested inline methods an order of + * iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED events is not important. + * * If any event overwrites either inline method or top parent method, + * then the parent, including inline methods, becomes invalid and its memory + * region is treated as unloaded. + * + * **Life time of allocated data**\n + * The client sends an event notification to the agent with event-specific + * data, which is a structure. The pointers in the structure refer to memory + * allocated by the client, which responsible for releasing it. The pointers are + * used by the iJIT_NotifyEvent method to copy client's data in a trace file, + * and they are not used after the iJIT_NotifyEvent method returns. + */ + +/** + * @defgroup jitapi JIT Profiling + * @ingroup internal + * @{ + */ + +/** + * @brief Enumerator for the types of notifications + */ +typedef enum iJIT_jvm_event +{ + iJVM_EVENT_TYPE_SHUTDOWN = 2, /**<\brief Send this to shutdown the agent. + * Use NULL for event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED = 13, /**<\brief Send when dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load as event + * data. */ +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UNLOAD_START, /**<\brief Send when compiled dynamic + * code is being unloaded from memory. + * Use iJIT_Method_Load as event data.*/ +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_UPDATE, /**<\brief Send to provide new content for + * a previously reported dynamic code. + * The previous content will be invalidated + * starting from the time of the notification. + * Use iJIT_Method_Load as event data but + * required fields are following: + * - method_id identify the code to update. + * - method_load_address specify start address + * within identified code range + * where update should be started. + * - method_size specify length of updated code + * range. */ + + + iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, /**<\brief Send when an inline dynamic + * code is JIT compiled and loaded + * into memory by the JIT engine, + * but before the parent code region + * starts executing. + * Use iJIT_Method_Inline_Load as event data.*/ + +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UPDATE_V2, +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 = 21, /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V2 as event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V3 as event data. */ +} iJIT_JVM_EVENT; + +/** + * @brief Enumerator for the agent's mode + */ +typedef enum _iJIT_IsProfilingActiveFlags +{ + iJIT_NOTHING_RUNNING = 0x0000, /**<\brief The agent is not running; + * iJIT_NotifyEvent calls will + * not be processed. */ + iJIT_SAMPLING_ON = 0x0001, /**<\brief The agent is running and + * ready to process notifications. */ +} iJIT_IsProfilingActiveFlags; + +/** + * @brief Description of a single entry in the line number information of a code region. + * @details A table of line number entries gives information about how the reported code region + * is mapped to source file. + * Intel(R) VTune(TM) Amplifier uses line number information to attribute + * the samples (virtual address) to a line number. \n + * It is acceptable to report different code addresses for the same source line: + * @code + * Offset LineNumber + * 1 2 + * 12 4 + * 15 2 + * 18 1 + * 21 30 + * + * VTune Amplifier constructs the following table using the client data + * + * Code subrange Line number + * 0-1 2 + * 1-12 4 + * 12-15 2 + * 15-18 1 + * 18-21 30 + * @endcode + */ +typedef struct _LineNumberInfo +{ + unsigned int Offset; /**<\brief Offset from the begining of the code region. */ + unsigned int LineNumber; /**<\brief Matching source line number offset (from beginning of source file). */ + +} *pLineNumberInfo, LineNumberInfo; + +/** + * @brief Enumerator for the code architecture. + */ +typedef enum _iJIT_CodeArchitecture +{ + iJIT_CA_NATIVE = 0, /**<\brief Native to the process architecture that is calling it. */ + + iJIT_CA_32, /**<\brief 32-bit machine code. */ + + iJIT_CA_64 /**<\brief 64-bit machine code. */ + +} iJIT_CodeArchitecture; + +#pragma pack(push, 8) + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, data provided with + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table.0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + unsigned int class_id; /**<\brief This field is obsolete. */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Load, iJIT_Method_Load; + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load_V2 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V2 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + The module name can be useful for distinguishing among + different JIT engines. VTune Amplifier will display + reported methods grouped by specific module. */ + +} *piJIT_Method_Load_V2, iJIT_Method_Load_V2; + +/** + * @brief Description of a JIT-compiled method + * @details The iJIT_Method_Load_V3 structure is the same as iJIT_Method_Load_V2 + * with a newly introduced 'arch' field that specifies architecture of the code region. + * When you use the iJIT_Method_Load_V3 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V3 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise they are + * treated as regions of different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Cannot be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + * The module name can be useful for distinguishing among + * different JIT engines. VTune Amplifier will display + * reported methods grouped by specific module. */ + + iJIT_CodeArchitecture module_arch; /**<\brief Architecture of the method's code region. + * By default, it is the same as the process + * architecture that is calling it. + * For example, you can use it if your 32-bit JIT + * engine generates 64-bit code. + * + * If JIT engine reports both 32-bit and 64-bit types + * of methods then VTune Amplifier splits the methods + * with the same module name but with different + * architectures in two different modules. VTune Amplifier + * modifies the original name provided with a 64-bit method + * version by ending it with '(64)' */ + +} *piJIT_Method_Load_V3, iJIT_Method_Load_V3; + +/** + * @brief Description of an inline JIT-compiled method + * @details When you use the_iJIT_Method_Inline_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Inline_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + unsigned int parent_method_id; /**<\brief Unique immediate parent's method ID. + * Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /** <\brief The virtual address on which the method + * is inlined. If NULL, then data provided with + * the event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Inline_Load, iJIT_Method_Inline_Load; + +/** @cond exclude_from_documentation */ +/** + * @brief Description of a segment type + * @details Use the segment type to specify a type of data supplied + * with the iJVM_EVENT_TYPE_METHOD_UPDATE_V2 event to be applied to + * a certain code trace. + */ +typedef enum _iJIT_SegmentType +{ + iJIT_CT_UNKNOWN = 0, + + iJIT_CT_CODE, /**<\brief Executable code. */ + + iJIT_CT_DATA, /**<\brief Data (not executable code). + * VTune Amplifier uses the format string + * (see iJIT_Method_Update) to represent + * this data in the VTune Amplifier GUI */ + + iJIT_CT_KEEP, /**<\brief Use the previous markup for the trace. + * Can be used for the following + * iJVM_EVENT_TYPE_METHOD_UPDATE_V2 events, + * if the type of the previously reported segment + * type is the same. */ + iJIT_CT_EOF +} iJIT_SegmentType; + +/** + * @brief Description of a dynamic update of the content within JIT-compiled method + * @details The JIT engine may generate the methods that are updated at runtime + * partially by mixed (data + executable code) content. When you use the iJIT_Method_Update + * structure to describe the update of the content within a JIT-compiled method, + * use iJVM_EVENT_TYPE_METHOD_UPDATE_V2 as an event type to report it. + * + * On the first Update event, VTune Amplifier copies the original code range reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event, then modifies it with the supplied bytes and + * adds the modified range to the original method. For next update events, VTune Amplifier + * does the same but it uses the latest modified version of a code region for update. + * Eventually, VTune Amplifier GUI displays multiple code ranges for the method reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event. + * Notes: + * - Multiple update events with different types for the same trace are allowed + * but they must be reported for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [code] Ignored + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + * - The types of previously reported events can be changed but they must be reported + * for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + */ + +typedef struct _iJIT_Method_Update +{ + void* load_address; /**<\brief Start address of the update within a method */ + + unsigned int size; /**<\brief The update size */ + + iJIT_SegmentType type; /**<\brief Type of the update */ + + const char* data_format; /**<\brief C string that contains a format string + * that follows the same specifications as format in printf. + * The format string is used for iJIT_CT_CODE only + * and cannot be NULL. + * Format can be changed on the fly. */ +} *piJIT_Method_Update, iJIT_Method_Update; + +/** @endcond */ + +#pragma pack(pop) + +/** @cond exclude_from_documentation */ +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#ifndef JITAPI_CDECL +# if defined WIN32 || defined _WIN32 +# define JITAPI_CDECL __cdecl +# else /* defined WIN32 || defined _WIN32 */ +# if defined _M_IX86 || defined __i386__ +# define JITAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define JITAPI_CDECL /* actual only on x86_64 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* defined WIN32 || defined _WIN32 */ +#endif /* JITAPI_CDECL */ + +#define JITAPI JITAPI_CDECL +/** @endcond */ + +/** + * @brief Generates a new unique method ID. + * + * You must use this API to obtain unique and valid method IDs for methods or + * traces reported to the agent if you don't have your own mechanism to generate + * unique method IDs. + * + * @return a new unique method ID. When out of unique method IDs, this API + * returns 0, which is not an accepted value. + */ +unsigned int JITAPI iJIT_GetNewMethodID(void); + +/** + * @brief Returns the current mode of the agent. + * + * @return iJIT_SAMPLING_ON, indicating that agent is running, or + * iJIT_NOTHING_RUNNING if no agent is running. + */ +iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive(void); + +/** + * @brief Reports infomation about JIT-compiled code to the agent. + * + * The reported information is used to attribute samples obtained from any + * Intel(R) VTune(TM) Amplifier collector. This API needs to be called + * after JIT compilation and before the first entry into the JIT-compiled + * code. + * + * @param[in] event_type - type of the data sent to the agent + * @param[in] EventSpecificData - pointer to event-specific data + * + * @returns 1 on success, otherwise 0. + */ +int JITAPI iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData); + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +/** @endcond */ + +/** @} jitapi group */ + +#endif /* __JITPROFILING_H__ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp new file mode 100644 index 0000000000..ef4c42bacf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp @@ -0,0 +1,317 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" + +#include "nchw_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void nchw_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) { + if (ws) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + size_t ws_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + if (ws_dt == data_type::u8) { + assert(0 <= value && value <= 255); + ws[ws_offset] = value; + } else + reinterpret_cast(ws)[ws_offset] = value; + } + }; + + auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto src_offset + = (size_t)IW * IH * ID * C * mb + + (size_t)IW * IH * ID * c + + (size_t)IW * IH * id + + (size_t)IW * ih + + (size_t)iw; + auto s = src[src_offset]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw); + } + } + } + } + }; + + auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KD*KW*KH + : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start); + + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + auto src_offset + = (size_t)IW * IH * ID * C * mb + + (size_t)IW * IH * ID * c + + (size_t)IW * IH * id + + (size_t)IW * ih + + (size_t)iw; + d[0] += src[src_offset]; + } + } + } + + d[0] = math::out_round((float)d[0] / num_summands); + }; + + + if (pd()->desc()->alg_kind == pooling_max) { + parallel_nd(MB, C, OD, OH, OW, + [&](int mb, int c, int od, int oh, int ow) { + size_t dst_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = nstl::numeric_limits::lowest(); + set_ws(mb, c, od, oh, ow, 0); + ker_max(d, mb, c, od, oh, ow); + }); + } else { + parallel_nd(MB, C, OD, OH, OW, + [&](int mb, int c, int od, int oh, int ow) { + size_t dst_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = 0; + ker_avg(d, mb, c, od, oh, ow); + }); + } +} + +template +void nchw_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper ws_d(pd()->workspace_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto ker_zero = [=](int mb, int c) { + size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW; + for (int id = 0; id < ID; ++id) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_offset++] = 0; + } + } + } + }; + + auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { + auto b_c = ws_d.blocking_desc().inner_nblks == 0 + ? 1 : ws_d.blocking_desc().inner_blks[0]; + auto ws_offset = is_3d + ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c + : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c; + + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset]; + const int kw = index % KW; + const int kh = (index / KW) % KH; + const int kd = (index / KW) / KH; + + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (id < 0 || id >= ID) + return; + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + size_t diff_src_offset = + (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW + (size_t)id*IH*IW + + (size_t)ih*IW + (size_t)iw; + diff_src[diff_src_offset] += d[0]; + }; + + auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + size_t num_summands = (alg == pooling_avg_include_padding) + ? (size_t)KW*KH*KD + : (size_t)(id_end - id_start)*(ih_end - ih_start) + *(iw_end - iw_start); + + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + + (size_t)c*ID*IH*IW + (size_t)id*IH*IW + + (size_t)ih*IW + (size_t)iw; + diff_src[diff_src_offset] += d[0] / num_summands; + } + } + } + }; + + if (pd()->desc()->alg_kind == pooling_max) { + parallel_nd(MB, C, [&](int mb, int c) { + size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + + (size_t)c*OD*OH*OW; + ker_zero(mb, c); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = &diff_dst[diff_dst_offset++]; + ker_max(d, mb, c, od, oh, ow); + } + } + } + }); + } else { + parallel_nd(MB, C, [&](int mb, int c) { + size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + + (size_t)c*OD*OH*OW; + ker_zero(mb, c); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = &diff_dst[diff_dst_offset++]; + ker_avg(d, mb, c, od, oh, ow); + } + } + } + }); + } +} + +template struct nchw_pooling_fwd_t; +template struct nchw_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp new file mode 100644 index 0000000000..bbdd04f6b9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp @@ -0,0 +1,147 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NCHW_POOLING_HPP +#define CPU_NCHW_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct nchw_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("nchw_pooling:any", nchw_pooling_fwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && !has_zero_dim_memory() + && utils::everyone_is(data_type, src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + nchw_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct nchw_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("nchw:any", nchw_pooling_bwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + bool ws_ok = true + && hint_fwd_pd_ + && hint_fwd_pd_->workspace_md(); + if (!ws_ok) + return status::unimplemented; + + const auto &ws_blk = + hint_fwd_pd_->workspace_md()->format_desc.blocking; + ws_ok = ws_ok + && ws_blk.inner_nblks < 1 + && IMPLICATION(ws_blk.inner_nblks == 1, + ws_blk.inner_idxs[0] == 1); + if (!ws_ok) + return status::unimplemented; + + ws_md_ = *hint_fwd_pd_->workspace_md(); + } + + return status::success; + } + }; + + nchw_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp new file mode 100644 index 0000000000..c0e93fefe4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp @@ -0,0 +1,382 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "ncsp_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void ncsp_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool calculate_stats = !pd()->stats_is_src(); + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = scratchpad.get(key_bnorm_tmp_mean); + variance = scratchpad.get(key_bnorm_tmp_var); + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool with_relu = pd()->with_relu_post_op(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; + const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); + dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; + dim_t N = pd()->MB(); + dim_t C = pd()->C(); + + int nthr = mkldnn_get_max_threads(); + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + size_t data_size = N * C * SP * sizeof(data_t); + bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); + + parallel(0, [&](const int ithr, const int nthr) { + int C_ithr = 0, C_nthr = 0; + int N_ithr = 0, N_nthr = 0; + int S_ithr = 0, S_nthr = 0; + + dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; + dim_t N_s = 0, N_e = 0; + dim_t S_s = 0, S_e = 0; + + dim_t C_blks_per_iter = 1; + int64_t iters = 1; + + if (do_blocking) { + size_t working_set_size = N * SP * sizeof(data_t); + bnorm_utils::cache_balance( + working_set_size, C, C_blks_per_iter, iters); + } else + C_blks_per_iter = C; + int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; + bool spatial_thr_allowed + = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, + C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, + N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); + balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + for (int64_t it = 0; it < iters; ++it) { + if (it == iters - 1 && iters > 1) { + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). So sync the + // threads if they are not synced by the algorithm. + if (SP_N_nthr == 1 && mkldnn_thr_syncable()) + mkldnn_thr_barrier(); + + S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + SP_N_ithr = N_ithr * S_nthr + S_ithr; + SP_N_nthr = N_nthr * S_nthr; + } + size_t C_off = it * C_blks_per_iter; + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). Since sync is not always + // possible (in case of TBB) use different parts of ws for each + // iteration if threads are not synced by the algorithm. + size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off; + + if (calculate_stats) { + data_t *mean_blk = mean + C_off; + data_t *variance_blk = variance + C_off; + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = (c + C_off) * SP; + data_t sum = 0; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : sum)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + sum += src[off + n * C * SP + sp]; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = sum; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + mean_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) + mean_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + mean_blk[c] /= (N * SP); + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t sum = 0.; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : sum)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + data_t m = src[off * SP + n * C * SP + sp] + - mean[off]; + sum += m * m; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = sum; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + variance_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) + variance_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + variance_blk[c] /= (N * SP); + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + } + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t sqrt_variance + = static_cast(sqrtf(variance[off] + eps)); + data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? scaleshift[C + off] : 0; + for (dim_t n = N_s; n < N_e; ++n) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t sp = S_s; sp < S_e; ++sp) { + size_t d_off = off * SP + n * C * SP + sp; + data_t bn_res + = sm * (src[d_off] - mean[off]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void ncsp_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + if (diff_scaleshift == nullptr) + diff_scaleshift = scratchpad.get(key_bnorm_tmp_diff_ss); + + const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); + dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; + dim_t C = pd()->C(), N = pd()->MB(); + const bool use_scaleshift = pd()->use_scaleshift(); + const float eps = pd()->desc()->batch_norm_epsilon; + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + int nthr = mkldnn_get_max_threads(); + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + size_t data_size = N * C * SP * sizeof(data_t); + bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); + + parallel(0, [&](const int ithr, const int nthr) { + int C_ithr = 0, C_nthr = 0; + int N_ithr = 0, N_nthr = 0; + int S_ithr = 0, S_nthr = 0; + + dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; + dim_t N_s = 0, N_e = 0; + dim_t S_s = 0, S_e = 0; + + dim_t C_blks_per_iter = 1; + int64_t iters = 1; + + if (do_blocking) { + size_t working_set_size = 2 * N * SP * sizeof(data_t); + bnorm_utils::cache_balance( + working_set_size, C, C_blks_per_iter, iters); + } else + C_blks_per_iter = C; + int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; + bool spatial_thr_allowed + = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, + C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, + N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); + balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + + for (int64_t it = 0; it < iters; ++it) { + if (it == iters - 1 && iters > 1) { + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). So sync the + // threads if they are not synced by the algorithm. + if (SP_N_nthr == 1 && mkldnn_thr_syncable()) + mkldnn_thr_barrier(); + + C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + SP_N_ithr = N_ithr * S_nthr + S_ithr; + SP_N_nthr = N_nthr * S_nthr; + } + size_t C_off = it * C_blks_per_iter; + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). Since sync is not always + // possible (in case of TBB) use different parts of ws for each + // iteration if threads are not synced by the algorithm. + size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off; + + data_t *diff_gamma_blk = diff_scaleshift + C_off; + data_t *diff_beta_blk = diff_scaleshift + C + C_off; + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t diff_gamma = 0.0, diff_beta = 0.0; + data_t v_mean = mean[off]; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + const size_t d_off = off * SP + n * C * SP + sp; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + diff_gamma += (src[d_off] - v_mean) * dd; + diff_beta += dd; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = diff_gamma; + ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter + + SP_N_ithr * C_blks_per_iter + c] = diff_beta; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + data_t sqrt_variance = static_cast( + 1.0f / sqrtf(variance[c + C_off] + eps)); + diff_gamma_blk[c] = 0.; + diff_beta_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) { + diff_gamma_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + diff_beta_blk[c] += ws_reduce[ws_iter_off + + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter + + c]; + } + diff_gamma_blk[c] *= sqrt_variance; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t gamma = use_scaleshift ? scaleshift[off] : 1; + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[off] + eps)); + data_t v_mean = mean[off]; + for (dim_t n = N_s; n < N_e; ++n) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t sp = S_s; sp < S_e; ++sp) { + const size_t d_off = off * SP + n * C * SP + sp; + + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_blk[c] / (SP * N) + + (src[d_off] - v_mean) * diff_gamma_blk[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp new file mode 100644 index 0000000000..97ca3b003f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp @@ -0,0 +1,160 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP +#define CPU_NCSP_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ncsp_batch_normalization_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_fwd_pd_t { + using cpu_batch_normalization_fwd_pd_t::cpu_batch_normalization_fwd_pd_t; + + DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + using namespace format_tag; + + bool ok = true + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + if (!stats_is_src()) { + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * C() * mkldnn_get_max_threads()); + + if (!is_training()) { + scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * C()); + scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * C()); + } + } + } + }; + + typedef typename prec_traits::type data_t; + + ncsp_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~ncsp_batch_normalization_fwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct ncsp_batch_normalization_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_bwd_pd_t { + using cpu_batch_normalization_bwd_pd_t::cpu_batch_normalization_bwd_pd_t; + + DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t); + + status_t init() { + using namespace data_type; + using namespace format_tag; + + bool ok = true + && is_bwd() + && !has_zero_dim_memory() + && utils::everyone_is(f32, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) + && memory_desc_matches_one_of_tag(*diff_src_md(), ncdhw, nchw, nc) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); + if (!(use_scaleshift() && desc()->prop_kind == prop_kind::backward)) + scratchpad.book(key_bnorm_tmp_diff_ss, + sizeof(data_t) * 2 * C()); + } + }; + + typedef typename prec_traits::type data_t; + + ncsp_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~ncsp_batch_normalization_bwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp new file mode 100644 index 0000000000..38cfb28dce --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp @@ -0,0 +1,392 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" + +#include "nhwc_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define MEM_D(name) name##_d + +#define DECLARE_READ_STRIDES(name) \ + const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \ + const size_t name##_d_stride = (!is_3d) \ + ? 0 \ + : MEM_D(name).blocking_desc().strides[2]; \ + const size_t name##_h_stride = (!is_3d) \ + ? MEM_D(name).blocking_desc().strides[2] \ + : MEM_D(name).blocking_desc().strides[3]; \ + const size_t name##_w_stride = (!is_3d) \ + ? MEM_D(name).blocking_desc().strides[3] \ + : MEM_D(name).blocking_desc().strides[4]; + +namespace nhwc_pooling { + size_t strided_offset(const int _n, const size_t _sn, + const int _d, const size_t _sd, + const int _h, const size_t _sh, + const int _w, const size_t _sw) + { + return _n * _sn + + _d * _sd + + _h * _sh + + _w * _sw; + } +} + +template +void nhwc_pooling_fwd_t::array_div_by_const(const int n, + const data_t *src, const size_t num, data_t *dst) const +{ + for (int i = 0; i < n; ++i) + { + float ftmp = (float)src[i]; + ftmp = ftmp / num; + dst[i] = math::out_round(ftmp); + } +} + +template +void nhwc_pooling_fwd_t::array_add(const int n, const data_t *src, + data_t *dst) const +{ + for (int i = 0; i < n; ++i) + { + dst[i] += src[i]; + } +} + +template +void nhwc_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace prop_kind; + using namespace nhwc_pooling; + + auto alg = pd()->desc()->alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper MEM_D(src)(pd()->src_md()); + const memory_desc_wrapper MEM_D(dst)(pd()->dst_md()); + const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + const bool is_3d = pd()->desc()->src_desc.ndims == 5; + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + DECLARE_READ_STRIDES(src); + DECLARE_READ_STRIDES(dst); + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + parallel_nd(MB, OD, OH, OW, + [&](int mb, int od, int oh, int ow) { + size_t dst_offset_init = strided_offset(mb, dst_n_stride, + od, dst_d_stride, + oh, dst_h_stride, + ow, dst_w_stride); + if (alg == pooling_max) { + size_t ws_offset_init = 0; + if (ws) + { + DECLARE_READ_STRIDES(ws); + ws_offset_init = strided_offset(mb, ws_n_stride, + od, ws_d_stride, + oh, ws_h_stride, + ow, ws_w_stride); + } + // Note: GCC 4.8.5 won't vectorize below + // simple loops unless they are singled out + // into separate helper routines: + // array_nhwc_initialize, array_nhwc_max + if (!ws) + array_nhwc_initialize(OC, dst + dst_offset_init, + ws, ws_offset_init, ws_dt); + else + array_nhwc_initialize(OC, dst + dst_offset_init, + ws, ws_offset_init, ws_dt); + + + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) + continue; + if (ih < 0 || ih >= IH) + continue; + if (iw < 0 || iw >= IW) + continue; + + size_t src_offset_init = strided_offset(mb, src_n_stride, + id, src_d_stride, + ih, src_h_stride, + iw, src_w_stride); + + if (!ws) + array_nhwc_max(OC, + dst + dst_offset_init, + src + src_offset_init, + ws, ws_offset_init, + ws_dt, + kd * KH * KW + kh * KW + kw + ); + else + array_nhwc_max(OC, + dst + dst_offset_init, + src + src_offset_init, + ws, ws_offset_init, + ws_dt, + kd * KH * KW + kh * KW + kw + ); + } + } else { + // pooling_avg + auto d = dst + dst_offset_init; + + utils::array_set(d, 0, OC); + + auto id_start = apply_offset(od * SD, padF); + auto ih_start = apply_offset(oh * SH, padT); + auto iw_start = apply_offset(ow * SW, padL); + auto id_end = nstl::min(od * SD - padF + KD, ID); + auto ih_end = nstl::min(oh * SH - padT + KH, IH); + auto iw_end = nstl::min(ow * SW - padL + KW, IW); + + // it is cheaper to actually count this in a loop + // as the typical kernel is small + size_t num_summands = 0; + + for (int id = id_start; id < id_end; ++id) + for (int ih = ih_start; ih < ih_end; ++ih) + for (int iw = iw_start; iw < iw_end; ++iw) { + size_t src_offset_init = strided_offset(mb, src_n_stride, + id, src_d_stride, + ih, src_h_stride, + iw, src_w_stride); + auto s = src + src_offset_init; + + // need to move the loop to separate function + // for GCC 4.8.5 to vectorize + array_add(OC, s, d); + + num_summands++; + } + + num_summands = (alg == pooling_avg_include_padding) ? + KW * KH * KD : num_summands; + + // need to move the loop to separate function + // for GCC 4.8.5 to vectorize + array_div_by_const(OC, d, num_summands, d); + } + }); +} + +template +void nhwc_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace nhwc_pooling; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md()); + const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md()); + const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int OC = pd()->C(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + auto alg = pd()->desc()->alg_kind; + + DECLARE_READ_STRIDES(diff_src); + DECLARE_READ_STRIDES(diff_dst); + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + const int MB = pd()->MB(); + + parallel_nd(MB, ID, IH, IW, + [&](int mb, int id, int ih, int iw) { + size_t src_offset_init = strided_offset(mb, diff_src_n_stride, + id, diff_src_d_stride, + ih, diff_src_h_stride, + iw, diff_src_w_stride); + + // check if kernel windows are disjoint, in this case there's no + // update needed and we just write there once, no initialization + // required. + if (!(KD == SD && KH == SH && KW == SW)) + for (int oc = 0; oc < OC; ++oc) + diff_src[src_offset_init + oc] = data_type_t(0); + + // Find out which output cells may correspond to current + // input position. Current input postition divided by + // stride, with integer divide rounding down, is the + // right-most output. + // Left-most output may be computed if we decrement input + // by (kernel_size - 1) and then do the same division by + // stride. + int od_left = nstl::max((id + padF - KD + 1) / SD, 0); + int oh_left = nstl::max((ih + padT - KH + 1) / SH, 0); + int ow_left = nstl::max((iw + padL - KW + 1) / SW, 0); + // Notice +1 here to preserve the C loop "less than" + // condition for continuing the for loop. + int od_right = nstl::min((id + padF) / SD + 1 , OD); + int oh_right = nstl::min((ih + padT) / SH + 1 , OH); + int ow_right = nstl::min((iw + padL) / SW + 1 , OW); + + for (int od = od_left; od < od_right; ++od) + for (int oh = oh_left; oh < oh_right; ++oh) + for (int ow = ow_left; ow < ow_right; ++ow) { + const int kd = id - od*SD + padF; + const int kh = ih - oh*SH + padT; + const int kw = iw - ow*SW + padL; + + if (kd < 0 || kd >= KD) + continue; + if (kh < 0 || kh >= KH) + continue; + if (kw < 0 || kw >= KW) + continue; + + size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, + od, diff_dst_d_stride, + oh, diff_dst_h_stride, + ow, diff_dst_w_stride); + + if (alg == pooling_max) { + DECLARE_READ_STRIDES(ws); + size_t ws_offset_init = strided_offset(mb, ws_n_stride, + od, ws_d_stride, + oh, ws_h_stride, + ow, ws_w_stride); + const int index = kd * KH * KW + kh * KW + kw; + + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < OC; ++oc) { + const int index_from_ws = + (MEM_D(ws).data_type() == data_type::u8) + ? (int)ws[ws_offset_init + oc] + : ((int *)ws)[ws_offset_init + oc]; + + const data_t d = diff_dst[dst_offset_init + oc]; + + // Check if kernel windows are disjoint, in this case + // there's no update needed and we just write there once + // otherwise we add value to the contents. + if (!(KD == SD && KH == SH && KW == SW)) + diff_src[src_offset_init + oc] += + (index_from_ws == index) + ? d + : data_type_t(0); + else + diff_src[src_offset_init + oc] = + (index_from_ws == index) + ? d + : data_type_t(0); + } + } else { + // pooling_avg + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) + ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < OC; ++oc) { + const data_t d = diff_dst[dst_offset_init + oc]; + // Check if kernel windows are disjoint, in this case + // there's no update needed and we just write there once + // otherwise we add value to the contents. + if (!(KD == SD && KH == SH && KW == SW)) + diff_src[src_offset_init + oc] += d / num_summands; + else + diff_src[src_offset_init + oc] = d / num_summands; + } + } + } + }); +} + +template struct nhwc_pooling_fwd_t; +template struct nhwc_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp new file mode 100644 index 0000000000..7e33b6869f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp @@ -0,0 +1,210 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NHWC_POOLING_HPP +#define CPU_NHWC_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace nhwc_pooling { +size_t strided_offset(const int _n, const size_t _sn, const int _d, + const size_t _sd, const int _h, const size_t _sh, const int _w, + const size_t _sw); +} + +template +struct nhwc_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("nhwc_pooling:any", nhwc_pooling_fwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nhwc : format_tag::ndhwc; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::everyone_is(data_type, + src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + nhwc_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + void array_div_by_const(const int n, const data_t *src, const size_t num, + data_t *dst) const; + void array_add(const int n, const data_t *src, data_t *dst) const; + + template + void array_nhwc_max(const int n, data_t *dst, const data_t *src, + unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt, + const int index) const { + assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < n; ++oc) { + auto s = src[oc]; + data_t mv = dst[oc]; + + // update index of maximum +#if defined __INTEL_COMPILER + if ((use_workspace) && (s > mv)) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + if (ws_dt == data_type::u8) { + assert(0 <= index && index <= 255); + ws[ws_offset + oc] = index; + } else + reinterpret_cast(ws)[ws_offset + oc] = index; + } +#else + // Need to add explicit predicates for GCC to vectorize this. + // And although the resulting code is ugly, it is still 4 times + // faster than scalar + if (use_workspace) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + + if (ws_dt == data_type::u8) { + assert(0 <= index && index <= 255); + unsigned char predicate = (s > mv) ? 0xff : 0; + unsigned char current_value = ws[ws_offset + oc]; + current_value = (predicate & (unsigned char)index) + | ((~predicate) & current_value); + ws[ws_offset + oc] = current_value; + } else { + auto wint = reinterpret_cast(ws); + unsigned int predicate = (s > mv) ? 0xffffffff : 0; + unsigned int current_value = wint[ws_offset + oc]; + current_value = (predicate & (unsigned int)index) + | ((~predicate) & current_value); + wint[ws_offset + oc] = current_value; + } + } +#endif + // update maximum + dst[oc] = nstl::max(s, mv); + } + } + + template + void array_nhwc_initialize(const int n, data_t *dst, unsigned char *ws, + const size_t ws_offset, const data_type_t ws_dt) const { + assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists + for (int oc = 0; oc < n; ++oc) { + if (use_workspace) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + if (ws_dt == data_type::u8) { + ws[ws_offset + oc] = 0; + } else + reinterpret_cast(ws)[ws_offset + oc] = 0; + } + dst[oc] = nstl::numeric_limits::lowest(); + } + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct nhwc_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("nhwc:any", nhwc_pooling_bwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::everyone_is(data_type, + diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + nhwc_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +}// namespace cpu +}// namespace impl +}// namespace mkldnn + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp new file mode 100644 index 0000000000..e20333e66f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp @@ -0,0 +1,288 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "nspc_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void nspc_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + const bool with_relu = pd()->with_relu_post_op(); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_mean = scratchpad.get(key_bnorm_tmp_mean); + auto tmp_var = scratchpad.get(key_bnorm_tmp_var); + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = tmp_mean; + variance = tmp_var; + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->H() * pd()->W() * pd()->D(); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr; + data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr; + + if (calculate_stats) { + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] = 0.; + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] += src[(size_t)n * SP * C + + sp * C + c]; + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + mean[c] = 0; + for (dim_t n = 0; n < nthr; n++) + mean[c] += ws_reduce[C * n + c]; + mean[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + mean_loc[c] = mean[c]; + ws_reduce[C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) { + data_t m = src[(size_t)n * SP * C + sp * C + c] + - mean_loc[c]; + ws_reduce[C * ithr + c] += m * m; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + variance[c] = 0; + for (dim_t n = 0; n < nthr; n++) + variance[c] += ws_reduce[C * n + c]; + variance[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) + variance_loc[c] = variance[c]; + } else { + variance_loc = variance; + mean_loc = mean; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + data_t sqrt_variance = static_cast( + sqrtf(variance_loc[c] + eps)); + data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? scaleshift[C + c] : 0; + size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void nspc_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_diff_ss = scratchpad.get(key_bnorm_tmp_diff_ss); + + if (diff_scaleshift == nullptr) + diff_scaleshift = tmp_diff_ss; + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->D() * pd()->H() * pd()->W(); + data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C; + auto *ws_reduce = scratchpad.get(key_bnorm_reduction); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + + data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr; + data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr); + + for (dim_t c = 0; c < C; c++) { + ws_reduce[C * ithr + c] = 0.; + ws_reduce[C * nthr + C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd; + ws_reduce[C * nthr + C * ithr + c] += dd; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[c] + eps)); + diff_gamma[c] = 0; + diff_beta[c] = 0; + for (dim_t n = 0; n < nthr; n++) { + diff_gamma[c] += ws_reduce[C * n + c]; + diff_beta[c] += ws_reduce[C * nthr + C * n + c]; + } + diff_gamma[c] *= sqrt_variance; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + diff_gamma_loc[c] = diff_gamma[c]; + diff_beta_loc[c] = diff_beta[c]; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t gamma = use_scaleshift ? scaleshift[c] : 1; + data_t sqrt_variance + = static_cast(1.0f / sqrtf(variance[c] + eps)); + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_loc[c] / (SP * N) + + (src[d_off] - mean[c]) * diff_gamma_loc[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp new file mode 100644 index 0000000000..aad86b05a7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp @@ -0,0 +1,169 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NSPC_BATCH_NORMALIZATION_HPP +#define CPU_NSPC_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct nspc_batch_normalization_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + + bool ok = true + /* the algorithm requires barriers while switching + * between parallelization over N and C dimensions */ + && mkldnn_thr_syncable() + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + if (!stats_is_src()) { + dim_t sz = nstl::max(C(), 16) * mkldnn_get_max_threads(); + scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz); + scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz); + scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz); + } + } + }; + + typedef typename prec_traits::type data_t; + + nspc_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~nspc_batch_normalization_fwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct nspc_batch_normalization_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + + bool ok = true + /* the algorithm requires barriers while switching + * between parallelization over N and C dimensions */ + && mkldnn_thr_syncable() + && is_bwd() + && !has_zero_dim_memory() + && utils::everyone_is(f32, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && memory_desc_matches_tag(*diff_src_md(), format_tag::nhwc) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); + scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C() + * (mkldnn_get_max_threads() + 1)); + } + }; + + typedef typename prec_traits::type data_t; + + nspc_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~nspc_batch_normalization_bwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp new file mode 100644 index 0000000000..d79b1a034b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp @@ -0,0 +1,265 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "simple_q10n.hpp" + +#include "ref_batch_normalization.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + /* fast return */ + if (this->pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const float *, MKLDNN_ARG_SCALE_SHIFT); + + auto mean = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const float *, MKLDNN_ARG_MEAN)) + : CTX_OUT_MEM(float *, MKLDNN_ARG_MEAN); + auto variance = pd()->stats_is_src() + ? const_cast(CTX_IN_MEM(const float *, MKLDNN_ARG_VARIANCE)) + : CTX_OUT_MEM(float *, MKLDNN_ARG_VARIANCE); + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper scaleshift_d(pd()->weights_md()); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + dim_t H = 1, W = 1, D = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); + if (has_spatial) { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + } + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift();; + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + + const bool with_relu = pd()->with_relu_post_op(); + auto maybe_post_op = [&](float res) { + return (with_relu && res < 0.0f) ? 0.0f : res; + }; + const bool is_3d = data_d.ndims() == 5; + + auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, + dim_t d, dim_t h, dim_t w) { + if (has_spatial) { + if (is_3d) + return data_d.off(n, c, d, h, w); + else + return data_d.off(n, c, h, w); + } else + return data_d.off(n, c); + }; + + parallel_nd(C, [&](dim_t c) { + float v_mean = calculate_stats ? 0 : mean[c]; + float v_variance = calculate_stats ? 0 : variance[c]; + + if (calculate_stats) { + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) + v_mean += src[data_offset(data_d, n, c, d, h, w)]; + v_mean /= W*N*H*D; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + float m = src[data_offset(data_d, n, c, d, h, w)] - v_mean; + v_variance += m*m; + } + v_variance /= W*H*N*D; + } + + float sqrt_variance = sqrtf(v_variance + eps); + float sm = (use_scaleshift + ? scaleshift[scaleshift_d.off(0, c)] + : 1.0f) / sqrt_variance; + float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + auto d_off = data_offset(data_d,n,c,d,h,w); + float bn_res = sm * ((float)src[d_off] - v_mean) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + if (data_type == data_type::s8) { + dst[d_off] = qz_a1b0()(maybe_post_op(bn_res)); + } else { + dst[d_off] = static_cast(maybe_post_op(bn_res)); + } + } + + if (calculate_stats) { + if (save_stats) { + mean[c] = v_mean; + variance[c] = v_variance; + } + } + }); +} + +template struct ref_batch_normalization_fwd_t; +template struct ref_batch_normalization_fwd_t; + +template +void ref_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + const memory_desc_wrapper scaleshift_d(pd()->weights_md()); + const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_md()); + + const dim_t C = pd()->C(); + + /* fast return */ + if (this->pd()->has_zero_dim_memory()) { + if (diff_scaleshift) { + for (dim_t c = 0; c < C; ++c) { + diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0; + diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0; + } + } + return; + } + + const dim_t N = pd()->MB(); + dim_t H = 1, W = 1, D = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); + if (has_spatial) { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + } + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + const bool is_3d = data_d.ndims() == 5; + + auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, + dim_t d, dim_t h, dim_t w) { + if (has_spatial) { + if (is_3d) + return data_d.off(n, c, d, h, w); + else + return data_d.off(n, c, h, w); + } else + return data_d.off(n, c); + }; + + parallel_nd(C, [&](dim_t c) { + data_t v_mean = mean[c]; + data_t v_variance = variance[c]; + data_t sqrt_variance = static_cast(1.0f / sqrtf(v_variance + eps)); + data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1; + data_t diff_gamma = data_t(0); + data_t diff_beta = data_t(0); + diff_gamma = 0.0; + diff_beta = 0.0; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + const size_t s_off = data_offset(data_d, n, c, d, h, w); + data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)]; + if (fuse_bn_relu && !ws[s_off]) + dd = 0; + + diff_gamma += (src[s_off] - v_mean) * dd; + diff_beta += dd; + } + diff_gamma *= sqrt_variance; + + if (diff_scaleshift) { + diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma; + diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta; + } + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + const size_t s_off = data_offset(data_d, n, c, d, h, w); + const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w); + data_t dd = diff_dst[dd_off]; + if (fuse_bn_relu && !ws[s_off]) + dd = 0; + + data_t v_diff_src = dd; + if (calculate_diff_stats) { + v_diff_src -= diff_beta/(D*W*H*N) + + (src[s_off] - v_mean) * + diff_gamma*sqrt_variance/(D*W*H*N); + } + v_diff_src *= gamma*sqrt_variance; + diff_src[dd_off] = v_diff_src; + } + }); +} + +template struct ref_batch_normalization_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp new file mode 100644 index 0000000000..aa9f74125a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_BATCH_NORMALIZATION_HPP +#define CPU_REF_BATCH_NORMALIZATION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_batch_normalization_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && IMPLICATION(use_scaleshift(), + weights_md()->data_type == data_type::f32) + && (attr()->has_default_values() || with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (src_md()->data_type == data_type::s8 && !stats_is_src()) + return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + return status::success; + } + }; + + ref_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_batch_normalization_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t); + + status_t init() { + bool ok = true + && is_bwd() + && utils::everyone_is(data_type, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), utils::everyone_is(data_type, + weights_md()->data_type, + diff_weights_md()->data_type)) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + ref_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp new file mode 100644 index 0000000000..4c534b5508 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp @@ -0,0 +1,97 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_CONCAT_HPP +#define REF_CONCAT_HPP + +#include "reorder_pd.hpp" + +#include "cpu_concat_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_concat_t: public cpu_primitive_t { + struct pd_t: public cpu_concat_pd_t { + using cpu_concat_pd_t::cpu_concat_pd_t; + + pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { + for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) + reorder_pds_.push_back( + (const reorder_pd_t *)rhs.reorder_pds_[i]->clone()); + } + ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; } + + DECLARE_CONCAT_PD_T("ref:any", ref_concat_t); + + status_t init() { + bool ok = cpu_concat_pd_t::init() == status::success; + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + auto r_impls = engine_->get_reorder_implementation_list(); + for (auto r = r_impls; *r; ++r) { + const primitive_attr_t attr; /* alpha == 1. */ + reorder_pd_t *r_pd = nullptr; + if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i), + engine_, src_image_md(i)) == status::success) { + r_pd->init_info(); + reorder_pds_.push_back(r_pd); + break; + } + } + } + + ok = reorder_pds_.size() == (size_t)n_; + return ok ? status::success : status::unimplemented; + } + + nstl::vector reorder_pds_; + }; + + ref_concat_t(const pd_t *apd): cpu_primitive_t(apd) { + const int n = pd()->n_inputs(); + reorders_.resize(n); + for (int i = 0; i < n; ++i) + pd()->reorder_pds_[i]->create_primitive(&reorders_[i]); + } + + ~ref_concat_t() { for (auto &r: reorders_) delete r; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto n = pd()->n_inputs(); + for (int i = 0; i < n; ++i) { + exec_args_t r_args; + r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i); + r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST); + exec_ctx_t r_ctx(ctx.stream(), std::move(r_args)); + reorders_[i]->execute(r_ctx); + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + nstl::vector reorders_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp new file mode 100644 index 0000000000..c0a979c4cf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp @@ -0,0 +1,395 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "type_helpers.hpp" + +#include "ref_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using math::saturate; +using math::get_bias; + +template +void ref_convolution_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool with_relu = 0; // TODO: change if support post_ops + const float nslope = 0.f; + + const int ndims = pd()->desc()->src_desc.ndims; + + auto ker = [=](int g, int mb, int oc, int od, int oh, + int ow) { + acc_data_t d = 0; + for (int ic = 0; ic < IC; ++ic) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + const int id = od * KSD - padFront + kd * (1 + KDD); + const int ih = oh * KSH - padT + kh * (1 + KDH); + const int iw = ow * KSW - padL + kw * (1 + KDW); + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + if (ndims == 5) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] + : weights[weights_d.off(oc, ic, kd, kh, kw)]); + else if (ndims == 4) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kh, kw)] + : weights[weights_d.off(oc, ic, kh, kw)]); + else if (ndims == 3) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kw)] + : weights[weights_d.off(oc, ic, kw)]); + else + assert(false); + + } + return d; + }; + + parallel_nd(G, MB, OC, OD, OH, OW, + [&](int g, int mb, int oc, int od, int oh, int ow) { + float a = bias + ? get_bias(bias, bias_d.off(g * OC + oc), + pd()->desc()->bias_desc.data_type) + : 0; + a += ker(g, mb, oc, od, oh, ow); + if (with_relu && a < 0) + a = a * nslope; + if (ndims == 5) + dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate(a); + else if (ndims == 4) + dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate(a); + else if (ndims == 3) + dst[dst_d.off(mb, g*OC + oc, ow)] = saturate(a); + else + assert(false); + }); +} + +template +void ref_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const int ndims = pd()->desc()->diff_src_desc.ndims; + + auto ker = [=](int g, int mb, int ic, int id, int ih, + int iw) { + acc_data_t d = 0; + for (int oc = 0; oc < OC; ++oc) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + if (iw + padL < kw * (1 + KDW) + || ih + padT < kh * (1 + KDH) + || id + padFront < kd * (1 + KDD)) + continue; + int ow = iw - kw * (1 + KDW) + padL; + int oh = ih - kh * (1 + KDH) + padT; + int od = id - kd * (1 + KDD) + padFront; + if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) + continue; + + ow /= KSW; + oh /= KSH; + od /= KSD; + + if (od < OD && oh < OH && ow < OW) { + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, od, oh, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] + : weights[weights_d.off(oc, ic, kd, kh, kw)]); + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, oh, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kh, kw)] + : weights[weights_d.off(oc, ic, kh, kw)]); + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kw)] + : weights[weights_d.off(oc, ic, kw)]); + else + assert(false); + } + } + return d; + }; + + parallel_nd(G, MB, IC, ID, IH, IW, + [&](int g, int mb, int ic, int id, int ih, int iw) { + auto ds_idx = (ndims == 5) + ? diff_src_d.off(mb, g*IC + ic, id, ih, iw) + : (ndims == 4) + ? diff_src_d.off(mb, g*IC + ic, ih, iw) + : diff_src_d.off(mb, g*IC + ic, iw); + float a = bias + ? get_bias(bias, bias_d.off(g * IC + ic), + pd()->desc()->bias_desc.data_type) + : 0; + a += ker(g, mb, ic, id, ih, iw); + diff_src[ds_idx] = saturate(a); + }); +} + +template +void ref_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const int ndims = pd()->desc()->src_desc.ndims; + +auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) { + for (int mb = 0; mb < MB; ++mb) + for (int od = 0; od < OD; ++od) + for (int oh = 0; oh < OH; ++oh) + for (int ow = 0; ow < OW; ++ow) { + if (ow*KSW + kw * (1 + KDW) < padL + || oh*KSH + kh * (1 + KDH) < padT + || od*KSD + kd * (1 + KDD) < padFront + || ow*KSW + kw * (1 + KDW) >= IW + padL + || oh*KSH + kh * (1 + KDH) >= IH + padT + || od*KSD + kd * (1 + KDD) >= ID + padFront) + continue; + + int id = od*KSD - padFront + kd * (1 + KDD); + int ih = oh*KSH - padT + kh * (1 + KDH); + int iw = ow*KSW - padL + kw * (1 + KDW); + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, + oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)]; + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)] + * src[src_d.off(mb, g*IC + ic, ih, iw)]; + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)] + * src[src_d.off(mb, g*IC + ic, iw)]; + else + assert(false); + } + }; + + auto ker_bias = [=](acc_data_t &d, int g, int oc) { + for (int mb = 0; mb < MB; ++mb) + for (int od = 0; od < OD; ++od) + for (int oh = 0; oh < OH; ++oh) + for (int ow = 0; ow < OW; ++ow) { + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh, + ow)]; + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, + ow)]; + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]; + else + assert(false); + } + }; + + parallel_nd(G, OC, [&](int g, int oc) { + if (diff_bias) { + // XXX: loss of precision when bias is a float... + acc_data_t db = 0; + ker_bias(db, g, oc); + diff_bias[diff_bias_d.off(g*OC+oc)] + = saturate(db); + } + + for (int ic = 0; ic < IC; ++ic) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + acc_data_t dw = 0; + ker(dw, g, oc, ic, kd, kh, kw); + + if (ndims == 5) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kd, kh, kw) + : diff_weights_d.off(oc, ic, kd, kh, kw); + diff_weights[idx] = saturate(dw); + } else if (ndims == 4) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kh, kw) + : diff_weights_d.off(oc, ic, kh, kw); + diff_weights[idx] = saturate(dw); + } else if (ndims == 3) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kw) + : diff_weights_d.off(oc, ic, kw); + diff_weights[idx] = saturate(dw); + } else { + assert(false); + } + } + }); +} + +using namespace data_type; + +template struct ref_convolution_fwd_t; + +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; +template struct ref_convolution_fwd_t; + +template struct ref_convolution_bwd_data_t; + +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; +template struct ref_convolution_bwd_data_t; + +template struct ref_convolution_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp new file mode 100644 index 0000000000..7c83d0c6d4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp @@ -0,0 +1,194 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_CONVOLUTION_HPP +#define CPU_REF_CONVOLUTION_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, data_type::undef, + dst_type, acc_type) + && IMPLICATION(with_bias(), true + && IMPLICATION(src_type == u8, + utils::one_of(bias_md_.data_type, f32, s32, s8, u8)) + && IMPLICATION(src_type == f32, + bias_md_.data_type == f32)) + && set_default_formats() + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, data_type::undef, + diff_dst_type, acc_type) + && set_default_formats() + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; + } + + virtual bool support_bias() const override { return true; } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + using cpu_convolution_bwd_weights_pd_t::cpu_convolution_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, diff_wei_type, diff_wei_type, + diff_dst_type, acc_type) + && set_default_formats() + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type diff_wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp new file mode 100644 index 0000000000..541a303aab --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp @@ -0,0 +1,199 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "math_utils.hpp" + +#include "ref_deconvolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +void ref_deconvolution_fwd_t::compute_fwd_bias(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int OD = pd()->OD(); + const int OC = pd()->OC() / G; + const int ndims = pd()->desc()->src_desc.ndims; + + parallel_nd(MB, G, OC, OD, OH, OW, + [&](int mb, int g, int oc, int od, int oh, int ow) { + auto b = bias[g * OC + oc]; + switch (ndims) { + case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break; + case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break; + case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break; + default: assert(!"invalid dimension size"); + } + }); +} + +void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int SP = pd()->OW()*pd()->OH()*pd()->OD(); + + parallel_nd(MB, OC, [&](int mb, int oc) { + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + auto offset = (size_t)(mb * OC + oc) * SP + sp; + dst[offset] += bias[oc]; + } + }); +} + +template +void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int SP = pd()->OW() * pd()->OH() * pd()->OD(); + + const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0]; + + parallel_nd(MB, utils::div_up(OC, blksize), SP, + [&](int mb, int oc_blk, int sp) { + int oc = oc_blk * blksize; + auto offset = mb * stride_mb + oc * SP + sp * blksize; + const int blk = nstl::min(blksize, OC - oc); + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blk; ++i) + dst[offset + i] += bias[oc + i]; + }); +} + +void ref_deconvolution_bwd_weights_t::compute_bwd_bias(const data_t *diff_dst, + data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int OC = pd()->OC() / G; + const int OD = pd()->OD(); + const int ndims = pd()->desc()->src_desc.ndims; + + parallel_nd(G, OC, [&](int g, int oc) { + data_t db = 0; + for (int mb = 0; mb < MB; ++mb) { + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + switch (ndims) { + case 5: + db += diff_dst[diff_dst_d.off( + mb, g * OC + oc, od, oh, ow)]; + break; + case 4: + db += diff_dst[diff_dst_d.off( + mb, g * OC + oc, oh, ow)]; + break; + case 3: + db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)]; + break; + default: assert(!"invalid dimension size"); + } + } + } + } + } + diff_bias[g * OC + oc] = db; + }); +} + +void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw( + const data_t *diff_dst, data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int OC = pd()->OC(); + const int MB = pd()->MB(); + const int SP = pd()->OH()*pd()->OW()*pd()->OD(); + + parallel_nd(OC, [&](int oc) { + data_t db = 0; + for (int mb = 0; mb < MB; ++mb) { + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + auto offset = (size_t)(mb * OC + oc) * SP + sp; + db += diff_dst[offset]; + } + } + diff_bias[oc] = db; + }); +} + +template +void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc( + const data_t *diff_dst, data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int OC = pd()->OC(); + const int MB = pd()->MB(); + const int SP = pd()->OH() * pd()->OW() * pd()->OD(); + + const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0]; + + parallel_nd(utils::div_up(OC, blksize), [&](int ocb) { + data_t db[blksize] = {0}; + + for (int mb = 0; mb < MB; ++mb) { + for (int sp = 0; sp < SP; ++sp) { + auto offset = mb * stride_mb + (ocb * SP + sp) * blksize; + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blksize; ++i) + db[i] += diff_dst[offset+i]; + } + } + + const int blk = nstl::min(blksize, OC - ocb * blksize); + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blk; ++i) + diff_bias[ocb * blksize + i] = db[i]; + }); +} + +template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>( + const data_t *diff_dst, data_t *diff_bias) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp new file mode 100644 index 0000000000..d61903c32d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp @@ -0,0 +1,502 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_DECONVOLUTION_HPP +#define CPU_REF_DECONVOLUTION_HPP + +#include +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "primitive_iterator.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_deconvolution_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +static status_t compute_blocked_format(bool with_groups, + const memory_desc_t *oi_md, memory_desc_t *io_md) +{ + /* Computes blocking for *i*o* format from *o*i* format */ + + bool sanity_check_ok = true + && oi_md->ndims == io_md->ndims + && oi_md->format_kind == format_kind::blocked; + if (!sanity_check_ok) return status::invalid_arguments; + + const blocking_desc_t &oi_blk = oi_md->format_desc.blocking; + blocking_desc_t io_blk = io_md->format_desc.blocking; + + io_md->format_kind = format_kind::blocked; + io_blk = oi_blk; + + const int ID_OC = 0 + with_groups; + const int ID_IC = 1 + with_groups; + + nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]); + for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) { + if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) { + io_blk.inner_idxs[i_blk] = + (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC); + } + } + + return memory_desc_init_by_blocking_desc(*io_md, io_blk); +} + +static status_t conv_descr_create(const deconvolution_desc_t *dd, + convolution_desc_t *cd) +{ + using namespace prop_kind; + alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct + ? alg_kind::convolution_direct : alg_kind::convolution_winograd; + + const memory_desc_t *src_md, *dst_md, *d_weights_d; + prop_kind_t prop_kind; + memory_desc_t c_weights_d; + if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) { + prop_kind = backward_data; + src_md = &dd->dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->weights_desc; + } else if (dd->prop_kind == backward_data) { + prop_kind = forward_training; + src_md = &dd->diff_dst_desc; + dst_md = &dd->diff_src_desc; + d_weights_d = &dd->weights_desc; + } else { + prop_kind = dd->prop_kind; + src_md = &dd->diff_dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->diff_weights_desc; + } + + const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; + + /* create weights desc for convolution */ + c_weights_d = *d_weights_d; + + const int ID_OC = 0 + with_groups; + const int ID_IC = 1 + with_groups; + + nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]); + nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]); + nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]); + + if (c_weights_d.format_kind != format_kind::any) + CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d)); + + return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, + prop_kind != backward_weights ? &dd->bias_desc : nullptr, + dst_md, dd->strides, dd->dilates, + dd->padding[0], dd->padding[1], dd->padding_kind); +} + +struct ref_deconvolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_fwd_pd_t { + pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_fwd_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + , conv_supports_bias_(other.conv_supports_bias_) + , dst_tag_(other.dst_tag_) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + CHECK(conv_descr_create(desc(), &cd)); + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + conv_supports_bias_ = + static_cast(conv_pd_) + ->support_bias(); + bool output_f32 = utils::everyone_is(data_type::f32, + desc()->accum_data_type, desc()->dst_desc.data_type); + + bool ok = true + && conv_pd_->weights_md()->extra.flags == 0 + /* deconv reference code can process only f32 bias */ + && IMPLICATION(with_bias(), + conv_supports_bias_ || output_f32); + if (ok) return status::success; + + delete conv_pd_; + } + conv_pd_ = nullptr; + return status::unimplemented; + } + + status_t init() { + using namespace format_tag; + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd) + && attr()->post_ops_.has_default_values(); + + if (ok) { + CHECK(init_convolution()); + if (weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->weights_md(), &desc_.weights_desc)); + weights_md_ = desc_.weights_desc; + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (dst_md_.format_kind == format_kind::any) + dst_md_ = *conv_pd_->diff_src_md(); + if (bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, x)); + + dst_tag_ = memory_desc_matches_one_of_tag(dst_md_, + utils::pick(ndims() - 3, ncw, nchw, ncdhw), + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), + utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + bool conv_supports_bias_; + format_tag_t dst_tag_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_fwd_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); + conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); + if (pd()->with_bias() && pd()->conv_supports_bias_) + conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS); + conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + conv_p_->execute(conv_ctx); + + if (pd()->with_bias() && !pd()->conv_supports_bias_) { + using namespace format_tag; + + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + switch (pd()->dst_tag_) { + case ncdhw: case nchw: case ncw: + compute_fwd_bias_ncdhw(bias, dst); + break; + case nCdhw8c: case nChw8c: case nCw8c: + compute_fwd_bias_nCdhwXc<8>(bias, dst); + break; + case nCdhw16c: case nChw16c: case nCw16c: + compute_fwd_bias_nCdhwXc<16>(bias, dst); + break; + default: + compute_fwd_bias(bias, dst); + break; + } + } + return status::success; + } + +private: + void compute_fwd_bias(const data_t *bias, data_t *dst) const; + void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const; + template void compute_fwd_bias_nCdhwXc(const data_t *bias, + data_t *dst) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +struct ref_deconvolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_bwd_data_pd_t { + pd_t(engine_t *engine, const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_bwd_data_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + status_t status = conv_descr_create(desc(), &cd); + if (status != status::success) return status; + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + if (conv_pd_->weights_md()->extra.flags == 0) + return status::success; + delete conv_pd_; + } + + return status::unimplemented; + } + + status_t init() { + using namespace data_type; + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && utils::everyone_is(data_type::f32, + desc()->diff_src_desc.data_type, + desc()->weights_desc.data_type, + desc()->diff_dst_desc.data_type) + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd); + + if (ok) { + CHECK(init_convolution()); + if (weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->weights_md(), &desc_.weights_desc)); + weights_md_ = desc_.weights_desc; + } + if (diff_src_md_.format_kind == format_kind::any) + diff_src_md_ = *conv_pd_->dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *conv_pd_->src_md(); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_bwd_data_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); + conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); + conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + conv_p_->execute(conv_ctx); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_bwd_weights_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + , dst_tag_(other.dst_tag_) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + status_t status = conv_descr_create(desc(), &cd); + if (status != status::success) return status; + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + if (conv_pd_->diff_weights_md()->extra.flags == 0) + return status::success; + delete conv_pd_; + } + return status::unimplemented; + } + + status_t init() { + using namespace format_tag; + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && utils::everyone_is(data_type::f32, + desc()->src_desc.data_type, + desc()->diff_weights_desc.data_type, + desc()->diff_dst_desc.data_type) + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd) + && attr()->has_default_values(); + if (ok) { + CHECK(init_convolution()); + if (diff_weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->diff_weights_md(), + &desc_.diff_weights_desc)); + diff_weights_md_ = desc_.diff_weights_desc; + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *conv_pd_->src_md(); + if (diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); + + dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_, + utils::pick(ndims() - 3, ncw, nchw, ncdhw), + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), + utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + format_tag_t dst_tag_; + }; + + typedef typename prec_traits::type data_t; + + ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_bwd_weights_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); + conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); + conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + status_t status = conv_p_->execute(conv_ctx); + if (status != status::success) return status; + + if (pd()->with_bias()) { + using namespace format_tag; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + switch (pd()->dst_tag_) { + case ncdhw: case nchw: case ncw: + compute_bwd_bias_ncdhw(diff_dst, diff_bias); + break; + case nCdhw8c: case nChw8c: case nCw8c: + compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias); + break; + case nCdhw16c: case nChw16c: case nCw16c: + compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias); + break; + default: + compute_bwd_bias(diff_dst, diff_bias); + break; + } + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const; + void compute_bwd_bias_ncdhw(const data_t *diff_dst, + data_t *diff_bias) const; + template void compute_bwd_bias_nCdhwXc( + const data_t *diff_dst, data_t *diff_bias) const; + + primitive_t *conv_p_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp new file mode 100644 index 0000000000..7beee8d323 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp @@ -0,0 +1,297 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace alg_kind; +using namespace math; + +ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, + float beta): alg_(alg), alpha_(alpha), beta_(beta) { + assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); +} + +ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( + const post_ops_t::entry_t::eltwise_t &eltwise) + : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {} + +float ref_eltwise_scalar_fwd_t::compute_scalar(float s) { + switch (alg_) { + case eltwise_relu: return relu_fwd(s, alpha_); + case eltwise_tanh: return tanh_fwd(s); + case eltwise_elu: return elu_fwd(s, alpha_); + case eltwise_square: return square_fwd(s); + case eltwise_abs: return abs_fwd(s); + case eltwise_sqrt: return sqrt_fwd(s); + case eltwise_linear: return linear_fwd(s, alpha_, beta_); + case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_); + case eltwise_soft_relu: return soft_relu_fwd(s); + case eltwise_logistic: return logistic_fwd(s); + default: assert(!"unknown eltwise alg_kind"); + } + + return 0.f; +} + +template +void ref_eltwise_fwd_t::execute_forward_nCspBc_padded( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + const blocking_desc_t &blk = data_d.blocking_desc(); + const int block = blk.inner_blks[0]; + + const int MB = pd()->MB(); + const int C = pd()->C() / block; + const int C_PADDED = data_d.padded_dims()[1] / block; + const int tail = pd()->C() % block; + const int SP = pd()->D() * pd()->H() * pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + auto ker = [=] (data_t &d, data_t s) { + switch (alg_kind) { + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: + d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }; + + // FIXME: integer overflow? + + parallel_nd(MB, C_PADDED, SP, + [&](int n, int c, int sp) { + auto d_off = (n*C_PADDED*SP + c*SP + sp) * block; + if (c < C) { + for (int v = 0; v < block; v++) + ker(dst[d_off + v], src[d_off + v]); + } else { + for (int v = 0; v < tail; v++) + ker(dst[d_off + v], src[d_off + v]); + } + }); +} + +template +void ref_eltwise_fwd_t::execute_forward_generic( + const exec_ctx_t &ctx) const { + /* fast return */ + if (pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int D = pd()->D(); + const int H = pd()->H(); + const int W = pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + const bool is_3d = pd()->desc()->data_desc.ndims == 5; + + parallel_nd(MB, C, D, H, W, + [&](int n, int c, int id, int h, int w) { + auto d_off = is_3d + ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w); + data_t s = src[d_off]; + data_t &d = dst[d_off]; + switch (alg_kind) { + case eltwise_relu: d = relu_fwd(s, alpha); break; + case eltwise_tanh: d = tanh_fwd(s); break; + case eltwise_elu: d = elu_fwd(s, alpha); break; + case eltwise_square: d = square_fwd(s); break; + case eltwise_abs: d = abs_fwd(s); break; + case eltwise_sqrt: d = sqrt_fwd(s); break; + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: + d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_fwd_t::execute_forward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const ptrdiff_t nelems = static_cast(data_d.nelems(true)); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + src += data_d.offset0(); + dst += data_d.offset0(); + + if (alg_kind == eltwise_relu) { + // a fast path for relu as the most popular activation + parallel_nd(nelems, [&](ptrdiff_t e) { + dst[e] = relu_fwd(src[e], alpha); + }); + return; + } + + parallel_nd(nelems, [&](ptrdiff_t e) { + const data_t s = src[e]; + data_t &d = dst[e]; + + switch (alg_kind) { + case eltwise_tanh: d = tanh_fwd(s); break; + case eltwise_elu: d = elu_fwd(s, alpha); break; + case eltwise_square: d = square_fwd(s); break; + case eltwise_abs: d = abs_fwd(s); break; + case eltwise_sqrt: d = sqrt_fwd(s); break; + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_bwd_t::execute_backward_generic( + const exec_ctx_t &ctx) const { + /* fast return */ + if (pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int D = pd()->D(); + const int H = pd()->H(); + const int W = pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + const bool is_3d = pd()->desc()->data_desc.ndims == 5; + + parallel_nd(MB, C, D, H, W, + [&](int n, int c, int d, int h, int w) { + auto data_off = is_3d + ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w); + auto diff_data_off = is_3d + ? diff_data_d.off(n, c, d, h, w) + : diff_data_d.off(n, c, h, w); + data_t s = src[data_off]; + data_t dd = diff_dst[diff_data_off]; + data_t &ds = diff_src[diff_data_off]; + switch (alg_kind) { + case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; + case eltwise_tanh: ds = tanh_bwd(dd, s); break; + case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; + case eltwise_square: ds = square_bwd(dd, s); break; + case eltwise_abs: ds = abs_bwd(dd, s); break; + case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; + case eltwise_linear: + ds = linear_bwd(dd, s, alpha, beta); break; + case eltwise_bounded_relu: + ds = bounded_relu_bwd(dd, s, alpha); break; + case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; + case eltwise_logistic: ds = logistic_bwd(dd, s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template +void ref_eltwise_bwd_t::execute_backward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const ptrdiff_t nelems = static_cast(data_d.nelems(true)); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + src += data_d.offset0(); + diff_dst += diff_data_d.offset0(); + diff_src += diff_data_d.offset0(); + + parallel_nd(nelems, [&](ptrdiff_t e) { + const data_t dd = diff_dst[e]; + const data_t s = src[e]; + data_t &ds = diff_src[e]; + + switch (alg_kind) { + case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; + case eltwise_tanh: ds = tanh_bwd(dd, s); break; + case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; + case eltwise_square: ds = square_bwd(dd, s); break; + case eltwise_abs: ds = abs_bwd(dd, s); break; + case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; + case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break; + case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break; + case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; + case eltwise_logistic: ds = logistic_bwd(dd, s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; +template struct ref_eltwise_fwd_t; + +template struct ref_eltwise_bwd_t; +template struct ref_eltwise_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp new file mode 100644 index 0000000000..8f4ab35413 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp @@ -0,0 +1,168 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_ELTWISE_HPP +#define CPU_REF_ELTWISE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_eltwise_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_eltwise_scalar_fwd_t { +public: + ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, float beta); + + // note that eltwise.scale is ignored + ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise); + + float compute_scalar(float s); + + const alg_kind_t alg_; + const float alpha_; + const float beta_; +}; + +template +struct ref_eltwise_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_eltwise_fwd_pd_t { + using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_eltwise_fwd_t); + + status_t init() { + using namespace utils; + + auto src_d = memory_desc_wrapper(src_md()); + + use_dense_ = false + || src_d.is_dense() + || (src_d.is_dense(true) && is_zero_preserved()); + + use_nCspBc_padded_ = !use_dense_ + && src_d.blocking_desc().inner_nblks == 1 + && one_of(src_d.blocking_desc().inner_blks[0], 8, 16) + && src_d.blocking_desc().inner_idxs[0] == 1 + && src_d.only_padded_dim(1) + && src_d.is_dense(true); + + if (has_zero_dim_memory()) + use_dense_ = use_nCspBc_padded_ = false; + + const bool use_generic = !use_dense_ && !use_nCspBc_padded_; + + bool ok = true + && is_fwd() + && everyone_is(data_type, desc()->data_desc.data_type) + && IMPLICATION(use_generic, one_of(src_d.ndims(), 4, 5)) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + return status::success; + } + + bool use_dense_, use_nCspBc_padded_; + }; + + ref_eltwise_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->use_dense_) + execute_forward_dense(ctx); + else if (pd()->use_nCspBc_padded_) + execute_forward_nCspBc_padded(ctx); + else + execute_forward_generic(ctx); + return status::success; + } + +private: + void execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const; + void execute_forward_dense(const exec_ctx_t &ctx) const; + void execute_forward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_eltwise_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_eltwise_bwd_pd_t { + using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_eltwise_bwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && !is_fwd() + && everyone_is(data_type, + desc()->data_desc.data_type, + desc()->diff_data_desc.data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + auto diff_dst_d = memory_desc_wrapper(diff_dst_md()); + const bool same_fmt_ = diff_dst_d == memory_desc_wrapper(src_md()); + + use_dense_ = true + && same_fmt_ + && diff_dst_d.is_dense(true) + && is_zero_preserved() + && !has_zero_dim_memory(); + const bool use_generic = !use_dense_; + + if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5)) + return status::unimplemented; + + return status::success; + } + + bool use_dense_; + }; + + ref_eltwise_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->use_dense_) + execute_backward_dense(ctx); + else + execute_backward_generic(ctx); + return status::success; + } + +private: + void execute_backward_dense(const exec_ctx_t &ctx) const; + void execute_backward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp new file mode 100644 index 0000000000..c807a9ffd0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp @@ -0,0 +1,285 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "math_utils.hpp" + +#include "ref_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using math::saturate; +using math::get_bias; + +template +void ref_inner_product_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5); + const int ndims = src_d.ndims() - 2; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f; + + auto ker_has_spatial = [=](int mb, int oc) { + acc_data_t d = 0; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int ic = 0; ic < IC; ++ic) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + switch (ndims) { + case 3: + d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)] + * weights[weights_d.off( + oc, ic, kd, kh, kw)]; + break; + case 2: + d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)] + * weights[weights_d.off(oc, ic, kh, kw)]; + break; + case 1: + d += (acc_data_t)src[src_d.off(mb, ic, kw)] + * weights[weights_d.off(oc, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + } + } + } + } + return d; + }; + + auto ker_no_spatial = [=](int mb, int oc) { + acc_data_t d = 0; + for (int ic = 0; ic < IC; ++ic) { + d += (acc_data_t)src[src_d.off(mb, ic)] + * weights[weights_d.off(oc, ic)]; + } + return d; + }; + + parallel_nd(MB, OC, [&](int mb, int oc) { + float a = bias + ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type) + : 0; + if (src_has_spatial) + a += ker_has_spatial(mb, oc); + else + a += ker_no_spatial(mb, oc); + if (do_relu && a < (acc_data_t)0) + a *= nslope; + dst[dst_d.off(mb, oc)] = saturate(a); + }); +} + +using namespace data_type; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; +template struct ref_inner_product_fwd_t; + +template +void ref_inner_product_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool diff_src_has_spatial + = utils::one_of(diff_src_d.ndims(), 3, 4, 5); + const int ndims = diff_src_d.ndims() - 2; + + parallel_nd(MB, IC, [&](int mb, int ic) { + if (diff_src_has_spatial) { + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + acc_data_t ds = acc_data_t(0); + for (int oc = 0; oc < OC; ++oc) { + switch (ndims) { + case 3: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kd, kh, kw)]); + break; + case 2: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kh, kw)]); + break; + case 1: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kw)]); + break; + default: assert(!"unsupported ndims size"); + } + } + switch (ndims) { + case 3: + diff_src[diff_src_d.off(mb, ic, kd, kh, kw)] + = (diff_src_data_t)ds; + break; + case 2: + diff_src[diff_src_d.off(mb, ic, kh, kw)] + = (diff_src_data_t)ds; + break; + case 1: + diff_src[diff_src_d.off(mb, ic, kw)] = (diff_src_data_t)ds; + break; + default: assert(!"unsupported ndims size"); + } + } + } else { + acc_data_t ds = acc_data_t(0); + for (int oc = 0; oc < OC; ++oc) { + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] * + weights[weights_d.off(oc, ic)]); + } + diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds; + } + }); +} + +template struct ref_inner_product_bwd_data_t; + +template +void ref_inner_product_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4 ,5); + const int ndims = src_d.ndims() - 2; + + parallel_nd(OC, IC, [&](int oc, int ic) { + if (src_has_spatial) { + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + data_t *dw(nullptr); + switch (ndims) { + case 3: + dw = &diff_weights[diff_weights_d.off( + oc, ic, kd, kh, kw)]; + break; + case 2: + dw = &diff_weights[diff_weights_d.off( + oc, ic, kh, kw)]; + break; + case 1: + dw = &diff_weights[diff_weights_d.off(oc, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + *dw = data_t(0); + for (int mb = 0; mb < MB; ++mb) { + switch (ndims) { + case 3: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kd, kh, kw)]; + break; + case 2: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kh, kw)]; + break; + case 1: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + } + } + } + } + } else { + data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)]; + *dw = data_t(0); + for (int mb = 0; mb < MB; ++mb) { + *dw += diff_dst[diff_dst_d.off(mb, oc)] * + src[src_d.off(mb, ic)]; + } + } + }); + + if (diff_bias) { + diff_bias += diff_bias_d.offset0(); + + parallel_nd(OC, [&](int oc) { + data_t *db = &diff_bias[oc]; + *db = data_t(0); + for (int mb = 0; mb < MB; ++mb) + *db += diff_dst[diff_dst_d.off(mb, oc)]; + }); + } +} + +template struct ref_inner_product_bwd_weights_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp new file mode 100644 index 0000000000..bf87dbd514 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp @@ -0,0 +1,159 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_INNER_PRODUCT_HPP +#define CPU_REF_INNER_PRODUCT_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && src_md()->data_type == src_type + && weights_md()->data_type == wei_type + && desc()->accum_data_type == acc_type + && dst_md()->data_type == dst_type + && IMPLICATION(with_bias(), utils::one_of( + weights_md(1)->data_type, f32, s32, s8, u8)) + && attr()->output_scales_.has_default_values() + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_ == 1, + attr()->post_ops_.entry_[0].is_relu(true, false)); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_inner_product_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_data_pd_t { + using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_data + && diff_src_md()->data_type == diff_src_type + && weights_md()->data_type == wei_type + && desc()->accum_data_type == acc_type + && diff_dst_md()->data_type == diff_dst_type + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type diff_src_data_t; + typedef typename prec_traits::type wei_data_t; + typedef typename prec_traits::type diff_dst_data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_inner_product_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_weights_pd_t { + using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_weights + && utils::everyone_is(data_type, + src_md()->data_type, + diff_dst_md()->data_type, + diff_weights_md()->data_type) + && IMPLICATION(with_bias(), + data_type == diff_weights_md(1)->data_type) + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp new file mode 100644 index 0000000000..325e97963b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp @@ -0,0 +1,252 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_lrn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +static inline float fast_negative_powf(float omega, float beta) { + float Y; +/* + * Y = omega^(-3/4) = + * = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega)) + * = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega) + * = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega) + * = sqrtf(1.0f / sqrtf(omega) / omega) + * = sqrtf(1.0f / (sqrtf(omega) * omega)) + */ + if (beta == 0.75f) { + Y = sqrtf(1.0f / (sqrtf(omega) * omega)); + } else { + Y = 1.0f / powf(omega, beta); + } + return Y; +}; + +template +template +void ref_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace format_tag; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const size_t stride_mb = data_d.blocking_desc().strides[0]; + const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels; + constexpr int blksize = tag == nChw16c ? 16 : 8; + + auto data_off = [&](int mb, int c, int h, int w) -> size_t { + switch (tag) { + case nChw16c: + case nChw8c: return mb * stride_mb + c / blksize * H * W * blksize + + h * W * blksize + w * blksize + c % blksize; + case nchw: return mb * stride_mb + c * H * W + h * W + w; + case nhwc: return mb * stride_mb + h * W * C + w * C + c; + default: return data_d.off(mb, c, h, w); + } + }; + + auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { + const float alpha = static_cast(pd()->desc()->lrn_alpha); + const float beta = static_cast(pd()->desc()->lrn_beta); + const float k = static_cast(pd()->desc()->lrn_k); + + const int size = pd()->desc()->local_size; + const int half_size = (size - 1) / 2; + + float sum = 0; + if (across_channels) { + const int c_st = nstl::max(oc - half_size + 0, 0); + const int c_en = nstl::min(oc + half_size + 1, C); + + for (int c = c_st; c < c_en; ++c) { + const float s = src[data_off(mb, c, oh, ow)]; + sum += s * s; + } + } else { + int h_st = nstl::max(oh - half_size + 0, 0); + int h_en = nstl::min(oh + half_size + 1, H); + int w_st = nstl::max(ow - half_size + 0, 0); + int w_en = nstl::min(ow + half_size + 1, W); + for (int h = h_st; h < h_en; ++h) { + for (int w = w_st; w < w_en; ++w) { + const float s = src[data_off(mb, oc, h, w)]; + sum += s * s; + } + } + } + const int summands = across_channels ? size : size * size; + sum = k + alpha * sum / summands; + size_t off = data_off(mb, oc, oh, ow); + d[0] = static_cast(src[off] * fast_negative_powf(sum, beta)); + }; + + const int MB = pd()->MB(); + if (tag == nChw16c || tag == nChw8c) { + parallel_nd(MB, utils::div_up(C, blksize), H, W, + [&](int mb, int c_blk, int h, int w) { + int c = c_blk * blksize; + const size_t off = mb * stride_mb + c * H * W + + (h * W + w) * blksize; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) + ker(&dst[off + cc], mb, c + cc, h, w); + }); + } else if (tag == nhwc) { + parallel_nd(MB, H, W, C, + [&](int mb, int h, int w, int c) { + const size_t off = mb * stride_mb + h * W * C + w * C + c; + ker(&dst[off], mb, c, h, w); + }); + } else { + parallel_nd(MB, C, H, W, + [&](int mb, int c, int h, int w) { + const size_t off = data_off(mb, c, h, w); + ker(&dst[off], mb, c, h, w); + }); + } +} + +template +template +void ref_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace format_tag; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const size_t stride_mb = data_d.blocking_desc().strides[0]; + constexpr int blksize = tag == nChw16c ? 16 : 8; + + const float alpha = static_cast(pd()->desc()->lrn_alpha); + const float beta = static_cast(pd()->desc()->lrn_beta); + const float k = static_cast(pd()->desc()->lrn_k); + const int kernel_size = pd()->desc()->local_size; + const int half_ksize = (kernel_size - 1) / 2; + + auto data_off = [&](int mb, int c, int h, int w) -> size_t { + switch (tag) { + case nChw16c: + case nChw8c: return mb * stride_mb + c/blksize * H * W * blksize + + h * W * blksize + w * blksize + c%blksize; + case nchw: return mb * stride_mb + c * H * W + h * W + w; + case nhwc: return mb * stride_mb + h * W * C + w * C + c; + default: return data_d.off(mb, c, h, w); + } + }; + + auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { + const int c_st = nstl::max(oc - half_ksize + 0, 0); + const int c_en = nstl::min(oc + half_ksize + 1, C); + + float A = 0, B = 0, omega_mid = 0; + for (int c = c_st; c < c_en; c++) { + float sum = 0.0; + const int i_st = nstl::max(c - half_ksize, 0); + const int i_en = nstl::min(c + kernel_size - half_ksize, C); + + for (int i = i_st; i < i_en; ++i) { + const float value = src[data_off(mb, i, oh, ow)]; + sum += value * value; + } + const float omega = static_cast(k + sum * alpha / kernel_size); + if (c == oc) omega_mid = omega; + float t = src[data_off(mb, c, oh, ow)] + * fast_negative_powf(omega, beta); + B += 1.0f / omega * t * diff_dst[data_off(mb, c, oh, ow)]; + } + + const size_t off = data_off(mb, oc, oh, ow); + A = fast_negative_powf(omega_mid, beta) * diff_dst[off]; + B *= src[off]; + B *= (2.0f * alpha * beta) / kernel_size; + *d = static_cast(A - B); // final cast down to data_t + }; + + if (tag == nChw16c || tag == nChw8c) { + parallel_nd(MB, utils::div_up(C, blksize), H, W, + [&](int mb, int c_blk, int h, int w) { + int c = c_blk * blksize; + const size_t off = mb * stride_mb + c * H * W + + (h * W + w) * blksize; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) + ker(&diff_src[off + cc], mb, c + cc, h, w); + }); + } else if (tag == nhwc) { + parallel_nd(MB, H, W, C, + [&](int mb, int h, int w, int c) { + const size_t off = mb * stride_mb + h * W * C + w * C + c; + ker(&diff_src[off], mb, c, h, w); + }); + } else { + parallel_nd(MB, C, H, W, + [&](int mb, int c, int h, int w) { + const size_t off = data_off(mb, c, h, w); + ker(&diff_src[off], mb, c, h, w); + }); + } +} + +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t:: +execute_forward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t:: +execute_backward(const exec_ctx_t &ctx) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp new file mode 100644 index 0000000000..f25cfb7fae --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp @@ -0,0 +1,136 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_LRN_HPP +#define CPU_REF_LRN_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t); + + status_t init() { + using namespace format_tag; + + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag( + *src_md(), nChw16c, nChw8c, nchw, nhwc); + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_lrn_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nChw16c: execute_forward(ctx); break; + case nChw8c: execute_forward(ctx); break; + case nchw: execute_forward(ctx); break; + case nhwc: execute_forward(ctx); break; + default: execute_forward(ctx); + } + return status::success; + } + +private: + template + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t); + + status_t init() { + using namespace format_tag; + using namespace alg_kind; + + bool ok = true + && !is_fwd() + && utils::one_of(desc()->alg_kind, lrn_across_channels + /*, lrn_within_channel */) // not supported yet + && utils::everyone_is(data_type, + src_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag( + *src_md(), nChw16c, nChw8c, nchw, nhwc); + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_lrn_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nChw16c: execute_backward(ctx); break; + case nChw8c: execute_backward(ctx); break; + case nchw: execute_backward(ctx); break; + case nhwc: execute_backward(ctx); break; + default: execute_backward(ctx); + } + return status::success; + } + +private: + template + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp new file mode 100644 index 0000000000..65b934e123 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp @@ -0,0 +1,381 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "ref_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_pooling_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace prop_kind; + + auto alg = pd()->desc()->alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->src_desc.ndims == 5; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) { + if (ws) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + size_t offset = is_3d + ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);; + if (ws_dt == data_type::u8) { + assert(0 <= value && value <= 255); + ws[offset] = value; + } else + reinterpret_cast(ws)[offset] = value; + } + }; + + auto ker_max = [=](data_t *d, int mb, int oc, int oh, int ow) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto s = src[src_d.off(mb, oc, ih, iw)]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, oc, 1, oh, ow, kh*KW + kw); + } + } + } + }; + + auto ker_avg = [=](data_t *d, int mb, int oc, int oh, int ow) { + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH + : (ih_end - ih_start)*(iw_end - iw_start); + + acc_data_t dst = 0; + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + dst += src[src_d.off(mb, oc, ih, iw)]; + } + } + + d[0] = math::out_round((float)dst / num_summands); + }; + + auto ker_max_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto s = src[src_d.off(mb, oc, id, ih, iw)]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw); + } + } + } + } + }; + + auto ker_avg_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + acc_data_t dst = 0; + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + dst += src[src_d.off(mb, oc, id, ih, iw)]; + } + } + } + + d[0] = math::out_round((float)dst / num_summands); + }; + + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + if (alg == pooling_max) { + parallel_nd(MB, OC, OD, OH, OW, + [&](int mb, int oc, int od, int oh, int ow) { + data_t *d = is_3d + ? &dst[dst_d.off(mb, oc, od, oh, ow)] + : &dst[dst_d.off(mb, oc, oh, ow)]; + d[0] = nstl::numeric_limits::lowest(); + set_ws(mb, oc, od, oh, ow, 0); + if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); + else ker_max(d, mb, oc, oh, ow); + }); + } else { + parallel_nd(MB, OC, OD, OH, OW, + [&](int mb, int oc, int od, int oh, int ow) { + data_t *d = is_3d + ? &dst[dst_d.off(mb, oc, od, oh, ow)] + : &dst[dst_d.off(mb, oc, oh, ow)]; + d[0] = 0; + if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); + else ker_avg(d, mb, oc, oh, ow); + }); + } +} + +template +void ref_pooling_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto ker_zero = [=](int _mb, int _oc) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_d.off(_mb, _oc, ih, iw)] = data_type_t(0); + } + } + }; + + auto ker_max = [=](const data_t *d, int mb, int oc, int oh, int ow) { + const size_t ws_off = ws_d.off(mb, oc, oh, ow); + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_off] : ((int *)ws)[ws_off]; + const int kw = index % KW; + const int kh = index / KW; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0]; + }; + + auto ker_avg = [=](const data_t *d, int mb, int oc, int oh, int ow) { + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH + : (ih_end - ih_start)*(iw_end - iw_start); + + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0] / num_summands; + } + } + }; + + auto ker_zero_3d = [=](int _mb, int _oc) { + for (int id = 0; id < ID; ++id) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_d.off(_mb, _oc, id, ih, iw)] = + data_type_t(0); + } + } + } + }; + + auto ker_max_3d = [=](const data_t *d, int mb, int oc, int od, int oh, + int ow) { + const size_t ws_off = ws_d.off(mb, oc, od, oh, ow); + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_off] : ((int *)ws)[ws_off]; + const int kw = index % KW; + const int kh = (index / KW) % KH; + const int kd = (index / KW) / KH; + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (id < 0 || id >= ID) + return; + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0]; + }; + + auto ker_avg_3d = [=](const data_t *d, int mb, int oc, int od, int oh, + int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + for (int id = id_start; id < id_end; ++id) + for (int ih = ih_start; ih < ih_end; ++ih) + for (int iw = iw_start; iw < iw_end; ++iw) { + diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0] / num_summands; + } + }; + + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + if (pd()->desc()->alg_kind == alg_kind::pooling_max) { + parallel_nd(MB, OC, [&](int mb, int oc) { + if (is_3d) ker_zero_3d(mb, oc); + else ker_zero(mb, oc); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = is_3d + ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] + : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; + if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); + else ker_max(d, mb, oc, oh, ow); + } + } + } + }); + } else { + parallel_nd(MB, OC, [&](int mb, int oc) { + if (is_3d) ker_zero_3d(mb, oc); + else ker_zero(mb, oc); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = is_3d + ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] + : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; + if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); + else ker_avg(d, mb, oc, oh, ow); + } + } + } + }); + } +} + +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; +template struct ref_pooling_fwd_t; + +template struct ref_pooling_bwd_t; +template struct ref_pooling_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp new file mode 100644 index 0000000000..e43ceaa82b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_POOLING_HPP +#define CPU_REF_POOLING_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_pooling_fwd_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::everyone_is(data_type, src_md()->data_type, + dst_md()->data_type) + && desc()->accum_data_type == acc_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + ref_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits::type data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct ref_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_pooling_bwd_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::everyone_is(data_type, diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + ref_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits::type data_t; + typedef typename prec_traits::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp new file mode 100644 index 0000000000..af27743110 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp @@ -0,0 +1,153 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_shuffle.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace format_tag; + +template +template +void ref_shuffle_t::execute_(const exec_ctx_t &ctx) const { + using namespace prop_kind; + using namespace utils; + + const memory_desc_wrapper data_d(pd()->data_md()); + + auto i_arg = pd()->is_fwd() ? MKLDNN_ARG_SRC : MKLDNN_ARG_DIFF_DST; + auto o_arg = pd()->is_fwd() ? MKLDNN_ARG_DST : MKLDNN_ARG_DIFF_SRC; + auto input = CTX_IN_MEM(const data_t *, i_arg); + auto output = CTX_OUT_MEM(data_t *, o_arg); + + const int axis = pd()->axis(); + const int axis_size = pd()->axis_size(); + + const int MB = pd()->MB(); + const int C = pd()->C(); + int H = 1, W = 1, D = 1, HW = 1, SP = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 3, 4 ,5); + if (has_spatial) + { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + HW = H * W; + SP = D * HW; + } + const size_t stride_mb = data_d.blocking_desc().strides[0]; + constexpr int blksize = one_of(tag, nChw16c, nCdhw16c) ? 16 : 8; + + if (axis == 1 && one_of(tag, nChw16c, nChw8c, nCdhw16c, nCdhw16c)) { +#if MKLDNN_THR == MKLDNN_THR_OMP +# pragma omp parallel for collapse(3) schedule(static) + for (int mb = 0; mb < MB; ++mb) + for (int cb = 0; cb < C; cb += blksize) + for (int sp = 0; sp < SP; ++sp) { + const size_t off = mb * stride_mb + sp * blksize; + const size_t output_off = off + cb * SP; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) + { + int input_c = rev_transposed_[cb + cc]; + const size_t input_off = off + input_c / blksize * SP * blksize + + input_c % blksize; + output[output_off + cc] = input[input_off]; + } + } +#else + parallel_nd(MB, utils::div_up(C, blksize), SP, [&](int mb, int c, + int sp) { + const size_t off = mb * stride_mb + sp * blksize; + const int cb = c * blksize; + const size_t output_off = off + cb * SP; + for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) + { + int input_c = rev_transposed_[cb + cc]; + const size_t input_off = off + input_c / blksize * SP * blksize + + input_c % blksize; + output[output_off + cc] = input[input_off]; + } + }); +#endif + } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) { + parallel_nd(MB, SP, [&](int mb, int sp) { + const size_t off = mb * stride_mb + sp * C; + PRAGMA_OMP_SIMD() + for (int c = 0; c < C; ++c) + output[off + c] = input[off + rev_transposed_[c]]; + }); + } else if (axis == 1 && one_of(tag, nchw, ncdhw)) { + parallel_nd(MB, C, [&](int mb, int c) { + const size_t output_off = mb * stride_mb + c * SP; + const size_t input_off = mb * stride_mb + rev_transposed_[c] * SP; + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + output[output_off + sp] = input[input_off + sp]; + } + }); + } else { + auto dims = pd()->desc()->data_desc.dims; + auto ndims = pd()->desc()->data_desc.ndims; + const size_t outer_size = utils::array_product(dims, axis); + const size_t inner_size = utils::array_product(dims + axis + 1, + ndims - axis - 1); + const size_t dim = axis_size * inner_size; + + parallel_nd(outer_size, axis_size, inner_size, [&](size_t ou, int a, + size_t in) + { + const size_t off = ou * dim + in; + auto &o = output[data_d.off_l(off + a * inner_size)]; + o = input[data_d.off_l(off + rev_transposed_[a] * inner_size)]; + }); + } +} + +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; + +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp new file mode 100644 index 0000000000..5e09a1a69b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp @@ -0,0 +1,111 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_SHUFFLE_HPP +#define CPU_REF_SHUFFLE_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_shuffle_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_shuffle_t : public cpu_primitive_t { + using shuffle_class = ref_shuffle_t; + + struct pd_t: public cpu_shuffle_pd_t { + using cpu_shuffle_pd_t::cpu_shuffle_pd_t; + + DECLARE_COMMON_PD_T("ref:any", shuffle_class); + + status_t init() { + using namespace format_tag; + + bool ok = true + && data_type_size + == types::data_type_size(data_md()->data_type); + if (!ok) return status::unimplemented; + + if (ndims() == 5) { + dat_tag_ = memory_desc_matches_one_of_tag( + *data_md(), nCdhw16c, nCdhw8c, ncdhw, ndhwc); + } else if (ndims() == 4) { + dat_tag_ = memory_desc_matches_one_of_tag( + *data_md(), nChw16c, nChw8c, nchw, nhwc); + } else + dat_tag_ = any; + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_shuffle_t(const pd_t *apd): cpu_primitive_t(apd) { + const int axis_size = pd()->axis_size(); + const int group_size = pd()->group_size(); + const int transpose_row = pd()->is_fwd() ? group_size + : axis_size / group_size; + const int transpose_col = pd()->is_fwd() ? axis_size / group_size + : group_size; + rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64); + parallel_nd(transpose_col, transpose_row, [&](int i, int j) { + rev_transposed_[j * transpose_col + i] = i * transpose_row + j; + }); + } + + ~ref_shuffle_t() { free(rev_transposed_); } + + typedef typename typesize_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nCdhw16c: execute_(ctx); break; + case nChw16c: execute_(ctx); break; + case nCdhw8c: execute_(ctx); break; + case nChw8c: execute_(ctx); break; + case ncdhw: execute_(ctx); break; + case nchw: execute_(ctx); break; + case ndhwc: execute_(ctx); break; + case nhwc: execute_(ctx); break; + default: execute_(ctx); break; + } + return status::success; + } + +private: + template + void execute_(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int *rev_transposed_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp new file mode 100644 index 0000000000..36d5237f56 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include +#include + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_softmax.hpp" +#include "gemm/os_blas.hpp" + +#ifdef USE_MKL +#include "mkl_vml_functions.h" +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +void ref_softmax_fwd_t::execute_forward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + parallel_nd(outer_size_, [&](int ou) { + const data_t *src_data = src + ou * channels_; + data_t *dst_data = dst + ou * channels_; + data_t scalar = 0; + + _max(channels_, src_data, &scalar); + _sub(channels_, scalar, src_data, dst_data); + _exp(channels_, dst_data, dst_data); + _sum(channels_, dst_data, &scalar); + _scal(channels_, data_t(1)/scalar, dst_data); + }); +} + +template +void ref_softmax_fwd_t::execute_forward_generic( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + data_t space_max_val = 0, space_denom_val = 0; + data_t *space_max = &space_max_val, *space_denom = &space_denom_val; + if (inner_size_ > 1) { + using namespace memory_tracking::names; + space_max = scratchpad(ctx).template get(key_softmax_reduction); + space_denom = space_max + inner_size_; + } + + const memory_desc_wrapper data_d(pd()->src_md()); + const size_t dim = channels_ * inner_size_; + + for (int ou = 0; ou < outer_size_; ou++) { + utils::array_set(space_max, -FLT_MAX, inner_size_); + utils::array_set(space_denom, 0, inner_size_); + + for (int c = 0; c < channels_; c++) { + for(int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + space_max[in] = nstl::max(space_max[in], src[off]); + } + } + + for (int c = 0; c < channels_; c++) { + for(int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + space_denom[in] += dst[off] = exp(src[off] - space_max[in]); + } + } + + for (int c = 0; c < channels_; c++) { + for (int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + dst[off] /= space_denom[in]; + } + } + } +} + +template +void ref_softmax_fwd_t::_max(int n, const data_t *x, + data_t *max_data) const { +// Intel(R) C++ Compiler generates the maxps + shuffle pattern +// for the max search which works faster +#if !defined(__INTEL_COMPILER) + // The code below makes a compiler to generate maxps instruction + // rather than maxss, which is generated for the 'else' code path + auto max_wrapper = [](data_t a, data_t b) { return nstl::max(a, b); }; + auto min_wrapper = [](int a, int b) { return nstl::min(a, b); }; + + constexpr int unroll_factor = 32; + data_t max_values[unroll_factor]; + + if (n < unroll_factor) { + data_t max_val = x[0]; + for (int i = 1; i < n; i++) { + max_val = max_wrapper(max_val, x[i]); + } + max_data[0] = max_val; + return; + } + for (int i = 0; i < unroll_factor; i++) { + max_values[i] = x[i]; + } + for (int i = unroll_factor; i < n; i += unroll_factor) { + int offset = min_wrapper(i, n - unroll_factor); + for (int j = 0; j < unroll_factor; j++) { + max_values[j] = max_wrapper(max_values[j], x[offset + j]); + } + } + data_t max_val = max_values[0]; + for (int i = 1; i < unroll_factor; i++) { + max_val = max_wrapper(max_val, max_values[i]); + } + max_data[0] = max_val; +#else + max_data[0] = x[0]; + for (int c = 1; c < n; ++c) + max_data[0] = nstl::max(max_data[0], x[c]); +#endif +} + +template +void ref_softmax_fwd_t::_sub(int n, data_t alpha, const data_t *x, + data_t *y) const { + constexpr int unroll_factor = 32; + int tail = n % unroll_factor; + for (int i = 0; i < n - tail; i += unroll_factor) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < unroll_factor; j++) { + y[i + j] = x[i + j] - alpha; + } + } + PRAGMA_OMP_SIMD() + for (int i = n - tail; i < n; i++) { + y[i] = x[i] - alpha; + } +} + +template +void ref_softmax_fwd_t::_exp(int n, const data_t *a, + data_t *r) const { +#ifdef USE_MKL + if (data_type == data_type::f32) { + vsExp(n, a, r); + return; + } +#endif + parallel_nd(n, [&](int c) { r[c] = expf(a[c]); }); +} + +template +void ref_softmax_fwd_t::_sum(int n, const data_t *x, + data_t *sum_data) const { +#ifdef USE_CBLAS + // Here we are summing x's eg. e^z , which are positives + // so we can use BLAS ASUM + if (data_type == data_type::f32) { + sum_data[0] = cblas_sasum(n, x, 1); + return; + } +#endif + data_t tsum = static_cast(0); + PRAGMA_OMP_SIMD(reduction(+ : tsum)) + for (int c = 0; c < n; ++c) + tsum += x[c]; + sum_data[0] = tsum; +} + +template +void ref_softmax_fwd_t::_scal(int n, data_t alpha, data_t *x) const { +#ifdef USE_CBLAS + if (data_type == data_type::f32) { + cblas_sscal(n, alpha, x, 1); + return; + } +#endif + parallel_nd(n, [&](int c) { x[c] *= alpha; }); +} + +template struct ref_softmax_fwd_t; + + +// NC/NCHW softmax for along final axe (1 for NC, 3 for NCHW) +template +void ref_softmax_bwd_t::execute_backward_dense( + const exec_ctx_t &ctx) const { + auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + parallel_nd(outer_size_, [&](int ou) { + data_t sbr = 0; + size_t off = channels_*ou; + for (int c = 0; c < channels_; c++) { + size_t loff = off + c; + data_t ldata = dst[loff]; + sbr += diff_dst[loff]*ldata; + diff_src[loff] = ldata; + } + + for(int c=0; c < channels_ ; ++c) { + size_t loff = off + c; + diff_src[loff] *= (diff_dst[loff] - sbr); + } + }); +} + +template +void ref_softmax_bwd_t::execute_backward_generic( + const exec_ctx_t &ctx) const { + auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_d(pd()->diff_src_md()); + const memory_desc_wrapper data_d(pd()->dst_md()); + + const size_t dim = channels_ * inner_size_; + + parallel_nd(outer_size_, [&](int ou) { + for (int in = 0; in < inner_size_; in++) { + data_t sbr = 0; + for (int c = 0; c < channels_; c++) { + size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); + size_t off_data = diff_d.off_l(ou * dim + c * inner_size_ + in); + sbr += diff_dst[off_diff] * dst[off_data]; + } + + for(int c=0; c < channels_ ; ++c) { + size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); + size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in); + diff_src[off_diff] = dst[off_data] * (diff_dst[off_diff] - sbr); + } + } + }); +} + +template struct ref_softmax_bwd_t; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp new file mode 100644 index 0000000000..5cb74d8007 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp @@ -0,0 +1,186 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_SOFTMAX_HPP +#define CPU_REF_SOFTMAX_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_softmax_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct ref_softmax_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_softmax_fwd_pd_t { + using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + const int inner_size = utils::array_product( + desc()->data_desc.dims + desc()->softmax_axis + 1, + desc()->data_desc.ndims - desc()->softmax_axis - 1); + + if (inner_size > 1) { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_softmax_reduction, + sizeof(data_t) * 2 * inner_size); + } + } + }; + + ref_softmax_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { + auto ndims = pd()->desc()->data_desc.ndims; + auto dims = pd()->desc()->data_desc.dims; + auto axis = pd()->desc()->softmax_axis; + + outer_size_ = utils::array_product(dims, axis); + channels_ = dims[axis]; + inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); + + const memory_desc_wrapper data_d(pd()->src_md()); + + bool no_axis_blocking = true; + for (int iblk = 0; iblk < data_d.blocking_desc().inner_nblks; ++iblk) + if (data_d.blocking_desc().inner_idxs[iblk] == axis) + no_axis_blocking = false; + + use_dense_ = inner_size_ == 1 && data_d.is_dense() + && no_axis_blocking + && data_d.blocking_desc().strides[axis] == 1; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (use_dense_) + execute_forward_dense(ctx); + else + execute_forward_generic(ctx); + return status::success; + } + +private: + void execute_forward_dense(const exec_ctx_t &ctx) const; + void execute_forward_generic(const exec_ctx_t &ctx) const; + + void _max(int n, const data_t *x, data_t *max_data) const; + void _sub(int n, data_t alpha, const data_t *x, data_t *y) const; + void _exp(int n, const data_t *a, data_t *r) const; + void _sum(int n, const data_t *x, data_t *sum_data) const; + void _scal(int n, data_t alpha, data_t *x) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + bool use_dense_; + int outer_size_, channels_, inner_size_; +}; + +template +struct ref_softmax_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_softmax_bwd_pd_t { + using cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t); + + status_t init() { + bool ok = true + && !is_fwd() + && utils::everyone_is(data_type, + dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + return status::success; + } + }; + + ref_softmax_bwd_t(const pd_t *apd): cpu_primitive_t(apd) { + auto dims = pd()->desc()->diff_desc.dims; + auto axis = pd()->desc()->softmax_axis; + auto ndims = pd()->desc()->diff_desc.ndims; + + outer_size_ = utils::array_product(dims, axis); + channels_ = dims[axis]; + inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); + + const memory_desc_wrapper data_d(pd()->dst_md()); + const memory_desc_wrapper diff_d(pd()->diff_dst_md()); + + bool no_axis_blocking = true; + for (int iblk = 0; iblk < diff_d.blocking_desc().inner_nblks; ++iblk) + if (diff_d.blocking_desc().inner_idxs[iblk] == axis) + no_axis_blocking = false; + + use_dense_ = true + && inner_size_ == 1 + && diff_d == data_d + && diff_d.is_dense() + && no_axis_blocking + && diff_d.blocking_desc().strides[axis] == 1; + } + + typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (use_dense_) + execute_backward_dense(ctx); + else + execute_backward_generic(ctx); + return status::success; + } + +private: + void execute_backward_dense(const exec_ctx_t &ctx) const; + void execute_backward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + bool use_dense_; + int outer_size_, channels_, inner_size_; +}; + + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp new file mode 100644 index 0000000000..3b2a75d99b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_SUM_HPP +#define REF_SUM_HPP + +#include "reorder_pd.hpp" + +#include "cpu_sum_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_sum_t: public cpu_primitive_t { + struct pd_t: public cpu_sum_pd_t { + using cpu_sum_pd_t::cpu_sum_pd_t; + + pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) { + for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) + reorder_pds_.push_back( + (const reorder_pd_t *)rhs.reorder_pds_[i]->clone()); + } + + ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; } + + DECLARE_SUM_PD_T("ref:any", ref_sum_t); + + status_t init() { + bool ok = cpu_sum_pd_t::init() == status::success; + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + auto r_impls = engine_->get_reorder_implementation_list(); + for (auto r = r_impls; *r; ++r) { + primitive_attr_t attr; + attr.output_scales_.set(scales_[i]); + if (i != 0) attr.post_ops_.append_sum(1.0); + + reorder_pd_t *r_pd; + if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i), + engine_, dst_md()) == status::success) { + r_pd->init_info(); + reorder_pds_.push_back(r_pd); + break; + } + } + } + + ok = reorder_pds_.size() == (size_t)n_; + return ok ? status::success : status::unimplemented; + } + + nstl::vector reorder_pds_; + }; + + ref_sum_t(const pd_t *apd): cpu_primitive_t(apd) { + const int n = pd()->n_inputs(); + reorders_.resize(n); + for (int i = 0; i < n; ++i) + pd()->reorder_pds_[i]->create_primitive(&reorders_[i]); + } + + ~ref_sum_t() { for (auto &r: reorders_) delete r; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto n = pd()->n_inputs(); + for (int i = 0; i < n; ++i) { + exec_args_t r_args; + r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i); + r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST); + exec_ctx_t r_ctx(ctx.stream(), std::move(r_args)); + reorders_[i]->execute(r_ctx); + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + nstl::vector reorders_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp new file mode 100644 index 0000000000..537084db91 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Common for RNN and LSTM cell execution + */ +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +using namespace rnn_utils; + +template +rnn_cell_execution_sig( + (_ref_rnn_common_t::cell_execution)) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + if (rnn_postgemm_ != nullptr) + rnn_postgemm_->execute(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + else + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} +template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) { + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + /// bwd by data on the cell + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, + 0.0, diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + + /// bwd by weights on the cell + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + + if (!rnn.merge_gemm_iter) + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, + diff_w_iter_, rnn.diff_weights_iter_ld); + + /// bwd by bias we just accumulate diffs from the gates + gates_reduction(rnn, ws_gates_, diff_bias_); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp new file mode 100644 index 0000000000..e1a61d4c62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +#define AOC array_offset_calculator +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_[0]); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + // 1. gemm Wx[0-2],x + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + + // 2. gemm Wh[0-1],h + (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb, + rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + // 3. activation zt and rt + elemwise multiplication rt,ht-1 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j); + } + }); + + // 4. gemm Wh[2],h~t + (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1], + rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0, + &(ws_gates(0, 2, 0)), rnn.gates_ws_ld); + + // 5. activation h~t + calculate ht + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) { + assert(!"GRU int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + // use state memory for intermediate computations + // TODO: use cell ws for that + float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0)); + float *hG1_ = dhG1_; + AOC dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld); + AOC hG1(hG1_, rnn.states_nld, rnn.states_ws_ld); + + // 1. calculate dG2, dG1, and part of dht-1 + // dG2^ = dh * (1 - G0) * (1 - G2^2) + // dG0^ = dh * (ht-1 - G2) * u * (1 - G0) + // dht-1 (part) = dh * G0 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt + * one_m_square(ws_gates(i, 2, j)); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 0, j) = dG0; + ws_gates(i, 2, j) = dG2; + } + }); + + // 2. calculate intermediate d(hG1) + // d(hG1) = dG2 * W2h^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1], + rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0, + dhG1_, rnn.states_ws_ld); + + // 3. calculate dG1^ and part of dht-1 + // dG1^ = d(hG1) * h * G1 * (1 - G1) + // dht-1 (part) += d(hG1) * G1 + // h * G1 (required for dWh) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float G1 = ws_gates(i, 1, j); + diff_states_t_l(0, i, j) += dhG1(i, j) * G1; + ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1); + hG1(i, j) = G1 * h; + } + }); + + // 4. calculate diff weights + // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h) + gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_iter_ld); + gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)), + rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0, + &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld); + + // 5. calculate diff states + // dht-1 += dG1 * W1h + dG0 * W0h + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, + (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0], + rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0, + diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + // dWx += [dG0 dG1 dG2] * [x] + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld); + } + + // 6. calculate diff bias + gates_reduction(rnn, ws_gates_, diff_bias_); +} +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp new file mode 100644 index 0000000000..8dea8c90a4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU with linear before reset + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_); + AOC ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j); + ws_gates(i, 0, j) = logistic_fwd( + ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd( + ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd( + ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + if (rnn.is_training) + ws_Wh_b(i, j) = Wh_b; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + AOC ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + // 1. calculate dG1 dG2 dG3 + // dG0 = (dht - G2) * dht * (1 - G0) * G0 + // dG1 = (W*h + b) * dG2 * (1 - G1) * G1 + // dG2 = (1 - G0) * dht * (1 - G2*G2) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + float dG2 = (1.0f - ws_gates(i, 0, j)) + * one_m_square(ws_gates(i, 2, j)) * dHt; + float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 2, j) = dG2; + ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j); + ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0; + ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1; + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) { + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + if (!rnn.merge_gemm_layer) { + // dx = dG * Wx^t + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + // dWx += dG^t * x + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + // dh += dGr * Wh^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld, + 1.0, diff_states_t_l_, rnn.states_ws_ld); + + // dWh += dGr^t * h + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_layer_ld); + + // db1-3 += e * dG + // db4 += e * (r * dG2) + gates_reduction(rnn, ws_gates_, diff_bias_); + + parallel_nd(rnn.dic, [&](int j) { + for (int i = 0; i < rnn.mb; i++) { + diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j); + } + }); +} + +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp new file mode 100644 index 0000000000..a15ba00d4c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "../simple_q10n.hpp" +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j)); + + float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j) + + ws_gates(i, 0, j) * ws_gates(i, 2, j); + states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) { + ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_u8_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + auto q_d = [&](float f) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + }; + + auto deq_w = [&](acc_data_t s, int gate, int j) { + return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ? + saturate(s) * (1.f / (weights_scales[0] * data_scale)) : + saturate(s) * (1.f / (weights_scales[gate * rnn.dic + j] + * data_scale)); + }; + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float G0 = logistic_fwd( + deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j)); + float G1 = logistic_fwd( + deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j)); + float G2 = tanh_fwd( + deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j)); + float G3 = logistic_fwd( + deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j)); + float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2; + states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp)); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Ct = c_states_t_l(i, j); + /// @todo save it in the workspace in fwd pass or recompute it to + /// save bw + float tanhCt = tanh_fwd(Ct); + // we have 2 incoming diffs on Ht + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dCt = diff_states_tp1_l(1, i, j) + + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt; + + float dG1 = c_states_tm1_l(i, j) * dCt + * x_m_square(ws_gates(i, 1, j)); + float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j)); + float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j)); + float dG2 + = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j)); + + diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j); + + ws_gates(i, 0, j) = dG0; + ws_gates(i, 1, j) = dG1; + ws_gates(i, 2, j) = dG2; + ws_gates(i, 3, j) = dG3; + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp new file mode 100644 index 0000000000..4536e8dfad --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution of Vanilla RNN + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return relu_fwd(s, alpha); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return relu_bwd(dd, s, alpha); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return tanh_fwd(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return dd * one_m_square(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return logistic_fwd(s); +} + +template <> +float activation( + float dd, float s, float alpha, float cliping) { + return dd * x_m_square(s); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; j++) { + const float h + = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0); + ws_gates(i, 0, j) = states_t_l(i, j) = h; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) { + assert(!"VANILLA RNN int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; ++j) { + const float dH = diff_states_t_lp1(rnn.n_states, i, j) + + diff_states_tp1_l(0, i, j); + auto g = ws_gates(i, 0, j); + ws_gates(i, 0, j) = activation_func(dH, g, 0, 0); + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp new file mode 100644 index 0000000000..b39427caf9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp @@ -0,0 +1,191 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_RNN_PD_HPP +#define CPU_RNN_PD_HPP + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "rnn_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "rnn_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t { + using rnn_fwd_pd_t::rnn_fwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace data_type; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + /* Int8 is supported only for packed weights */ + data_type_t weights_iter_dt = weights_iter_md_.data_type; + data_type_t weights_layer_dt = weights_layer_md_.data_type; + ok = ok && IMPLICATION( + weights_iter_dt == s8, weights_iter_md_.format_kind + == format_kind::rnn_packed); + ok = ok && IMPLICATION( + weights_layer_dt == s8, weights_layer_md_.format_kind + == format_kind::rnn_packed); + + return ok ? status::success : status::unimplemented; + } +}; + +struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t { + using rnn_bwd_pd_t::rnn_bwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + if (diff_src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc)); + if (diff_weights_layer_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo)); + } + if (diff_weights_iter_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo)); + } + if (diff_dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc)); + if (with_bias() && diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo)); + if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + ok = ok && is_blocked(diff_src_layer_md_, 3) + && is_blocked(diff_dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_), + is_blocked(diff_src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&diff_dst_iter_md_), + is_blocked(diff_dst_iter_md_, 5)); + + ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_) + && rnn_utils::is_ldigo(&diff_weights_iter_md_); + ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_), + memory_desc_matches_tag(diff_bias_md_, ldgo)); + + return ok ? status::success : status::unimplemented; + } +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp new file mode 100644 index 0000000000..09445648aa --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp @@ -0,0 +1,401 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "rnn_utils.hpp" +#include "../jit_generator.hpp" +#include "../jit_uni_eltwise.hpp" +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "mkldnn_thread.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_uni_rnn_postgemm_kernel : public jit_generator { + + typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_, + void *c_states_t_l_, void *c_states_tm1_l_); + + jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){} + + virtual void init() = 0; + +template + rnn_elemwise_sig(execute) { + rnn_utils::ws_gates_aoc ws_gates(rnn, ws_gates_); + rnn_utils::bias_aoc_t bias(rnn, bias_); + rnn_utils::ws_states_aoc states_t_l(rnn, states_t_l_); + rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + // Todo: add parallelization on dic for the batch 1 case + // Assumption: the kernel runs a loop on dic elements + parallel_nd(rnn.mb, [&](int i) { + auto b_ = &bias(0, 0); + auto g_ = &ws_gates(i, 0, 0); + auto s_tl_ = &states_t_l(i, 0); + auto c_tl_ = &c_states_t_l(i, 0); + auto c_tm1l_ = &c_states_tm1_l(i, 0); + kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_); + }); + } + +protected: + kernel_t kernel_; + const rnn_utils::rnn_conf_t &rnn_; + const primitive_attr_t *attr_; +}; + +template +struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd) + + typedef typename utils::conditional::type acc_data_t; + typedef typename utils::conditional, + jit_uni_eltwise_injector_f32>::type injector_t; + + jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr) + : jit_uni_rnn_postgemm_kernel(rnn, attr){} + + void init() override { + // we use rax for both constant tables as they use the same table + sigmoid_injector_ = new injector_t(this, + alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax); + tanh_injector_ = new injector_t(this, + alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax); + generate(); + kernel_ = (kernel_t) this->getCode(); + } + +protected: + injector_t *sigmoid_injector_; + injector_t *tanh_injector_; + + // register size in bytes + using Vmm = typename jit_uni_eltwise_injector_f32::Vmm; + size_t vlen = cpu_isa_traits::vlen; + size_t vlen_dst = (src_data_t == data_type::u8) ? vlen/4 : vlen; + size_t cstate_dt_size = sizeof(float); + size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float); + size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float); + size_t qscale_dt_size = sizeof(float); + size_t bias_dt_size = sizeof(float); + + void generate() { + using namespace Xbyak; + + int mask = attr_->rnn_weights_qparams_.mask_; + float *weights_scales = attr_->rnn_weights_qparams_.scales_; + float data_scale = attr_->rnn_data_qparams_.scale_; + float data_shift = attr_->rnn_data_qparams_.shift_; + + // Labels declaration + Label vector_loop_start_label, vector_loop_end_label; + Label rem_loop_start_label, rem_loop_end_label; + Label table_label; + + // Register map + Reg64 loop_cnt(r11); // loop counter + Reg64 table_reg(rbx); // table is used for data scale and shifts + Reg64 weights_scales_reg(r13); + // We skip vmm0 as it can be used by the injector for masks on sse4.2 + Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7); + + // constant table map + Address dscale_off_addr = ptr[table_reg]; + Address dshift_off_addr = ptr[table_reg + vlen]; + Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen]; + Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits::vlen]; + + // quantize from float to u8 + auto q_d = [&](Vmm f, Vmm tmp_vmm) { + uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm); + uni_vmulps(f, f, dscale_off_addr); // apply scale + uni_vaddps(f, f, dshift_off_addr); // apply shift + uni_vcvtps2dq(f, f); // convert to int32 + uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16 + uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation + // Note that the results are interleaved by 128 bit chunks, so we need to merge them together + switch (vlen) { + case 64: { //avx512 + Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx()); + uni_vmovups(tmpz, zmm_perm_mask_addr); + vpermd(fz, tmpz, fz); + break; } + case 32: { //avx + Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx()); + uni_vmovups(tmpy, ymm_perm_mask_addr); + vpermd(fy, tmpy, fy); + break; } + case 16: // sse: nothing to do + break; + default: assert(!"Unsupported case"); + }; + }; + + auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) { + if (packed) + uni_vrcpps(tmp, s); + else + uni_vrcpss(tmp, s); // prevent divide by zero + // we add one Newton iteration + uni_vmulps(s, s, tmp); + uni_vmulps(s, s, tmp); // s <- s * tmp^2 + uni_vaddps(tmp, tmp, tmp); + uni_vsubps(tmp, tmp, s); + uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2 + }; + + // dequantize from s32 to float + auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) { + // TODO: if mask is 0 precompute mul and inverse + if (mask == 0) + uni_vbroadcastss(tmp1, ptr[weights_scales_reg]); + else + uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]); + uni_vcvtdq2ps(s, s); + uni_vmulps(tmp1, tmp1, dscale_off_addr); + fast_recip(tmp1, tmp2, packed); + uni_vmulps(s, s, tmp1); + }; + + // We start code generations here + preamble(); + + // extract addresses passed as parameter +#ifdef _WIN32 + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = r10; + // Here we cannot use rbp to have initial stack pointer so we + // use rsp and offset it with the size of pushed registers in + // preamble + mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]); +#else + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = abi_param5; +#endif + + // initialize registers with addresses and constants + mov(table_reg, table_label); + mov(weights_scales_reg, size_t(weights_scales)); + // both sigmoid and tanh use the same table so load address just once in rax + sigmoid_injector_->load_table_addr(); + + mov(loop_cnt, rnn_.dic * gate_dt_size); + cmp(loop_cnt, vlen); + jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + + L(vector_loop_start_label); + { + // load G0 G1 G2 G3 + uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true); + } + + // add biases + uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0.getIdx()); + sigmoid_injector_->compute_vector(G1.getIdx()); + tanh_injector_->compute_vector(G2.getIdx()); + sigmoid_injector_->compute_vector(G3.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0 * G2 + uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1_vmm, tmp1_vmm, G1); + uni_vfmadd231ps(tmp1_vmm, G0, G2); + uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1_vmm.getIdx()); + uni_vmulps(tmp1_vmm, tmp1_vmm, G3); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, vlen); + add(addr_bias_reg, vlen); + add(addr_states_t_l_reg, vlen_dst); + add(addr_c_states_tm1_l_reg, vlen); + add(addr_c_states_t_l_reg, vlen); + if (mask != 0) + add(weights_scales_reg, vlen); + + // increment loop counter + sub(loop_cnt, vlen); + cmp(loop_cnt, vlen); + jge(vector_loop_start_label); + } + L(vector_loop_end_label); + + cmp(loop_cnt, 0); + je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + // Same code as above, we just use movuss for accessing inputs + // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar + L(rem_loop_start_label); + { + // remaping registers to Xmms + Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx()); + Xmm tmp1s_vmm(tmp1_vmm.getIdx()); + + // load G0 G1 G2 G3 + uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false); + } + + // add biases + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G0s, G0s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1s, G1s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2s, G2s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3s, G3s, tmp1s_vmm); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0s.getIdx()); + sigmoid_injector_->compute_vector(G1s.getIdx()); + tanh_injector_->compute_vector(G2s.getIdx()); + sigmoid_injector_->compute_vector(G3s.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0s * G2 + uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s); + uni_vfmadd231ps(tmp1s_vmm, G0s, G2s); + uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1s_vmm.getIdx()); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, gate_dt_size); + add(addr_bias_reg, bias_dt_size); + add(addr_states_t_l_reg, hstate_dt_size); + add(addr_c_states_tm1_l_reg, cstate_dt_size); + add(addr_c_states_t_l_reg, cstate_dt_size); + if (mask != 0) + add(weights_scales_reg, qscale_dt_size); + + // increment loop counter + sub(loop_cnt, gate_dt_size); + cmp(loop_cnt, 0); + jg(rem_loop_start_label); + + } + L(rem_loop_end_label); + + postamble(); + + // Again, only one table is needed and shared between sigmoid and tanh + sigmoid_injector_->prepare_table(false); + tanh_injector_->prepare_table(true); + + L(table_label); + { + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale)); + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift)); + // perm mask for ymm + dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7); + // perm mask for zmm + dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7); + dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14); + } + } + +}; + +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; + +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +template struct jit_uni_lstm_postgemm_kernel_fwd; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp new file mode 100644 index 0000000000..ead536816c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp @@ -0,0 +1,788 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + General architecture + + for diff states, we have n_states + 1 as we have n_states diff + to propagate to the previous iteration and 1 states to propagate + to the previous layer + index 0 is dh for cell(t-1, l) to consume + index 1 is dc for cell(t-1, l) to consume + index 2 is dh for cell(t, l-1) to consume + this indexing enables to have the same indexing for states in elemwise + function + only the cell execution function should be impacted + + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "../gemm/gemm.hpp" +#include "../simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::memory_tracking::names; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template +void _ref_rnn_common_t::gates_reduction( + const rnn_conf_t &rnn, const acc_data_t *ws_gates_, + float *diff_bias_) const { + auto body = [&](int i, int k) { + for (int j = 0; j < rnn.mb; j++) + diff_bias_[i * rnn.dic + k] + += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k]; + }; + + // @todo block k on simd-width +#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \ + /* icc 17.0 has a problem with simd collapse */ \ + && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700)) +#pragma omp parallel for simd collapse(2) + for (int i = 0; i < rnn.n_gates; i++) + for (int k = 0; k < rnn.dic; k++) + body(i, k); +#else + parallel_nd(rnn.n_gates, rnn.dic, body); +#endif +} + +template +rnn_gemm_sig((_ref_rnn_common_t::gemm)) { + assert(ldA * ldB * ldC != 0); + extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB, + &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm); +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) { + assert(!"non packed gemm is disabled for int8"); +} + +template +rnn_gemm_sig((_ref_rnn_common_t::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + assert(transA == 'N'); + cblas_sgemm_compute(CblasColMajor, CblasPacked, + (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_, + ldB, beta, c_, ldC); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + int8_t offseta = 0, offsetb = 0; + int32_t offsetc = 0; + cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked, + CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_, + ldB, offsetb, beta, c_, ldC, &offsetc); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +//*************** Grid computations strategy: linear ***************// +template +rnn_grid_execution_sig( + (_ref_rnn_common_t::linear_execution)) { + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, + rnn.states_nld * rnn.states_ws_ld); + AOC ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter, + rnn.gates_nld * rnn.gates_ws_ld); + AOC weights_input( + weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer); + AOC weights_states( + weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter); + AOC bias( + bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC diff_weights_layer(diff_weights_layer_, rnn.n_layer, + rnn.n_dir, + rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld); + AOC diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir, + rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld); + AOC diff_bias( + diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + AOC ws_grid( + ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell); + + // We run the grid of computation + for (int dir = 0; dir < rnn.n_dir; dir++) { + for (int j = 0; j < rnn.n_layer; j++) { + int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1; + + if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, + rnn.mb * rnn.n_iter, rnn.slc, 1.0, + weights_input(lay, dir, 0), rnn.weights_iter_ld, + &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0, + &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld); + } + + for (int i = 0; i < rnn.n_iter; i++) { + int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1; + (this->*cell_func)(rnn, + &(ws_states(lay + 1, dir, iter + 1, 0)), + &(ws_c_states(lay + 1, dir, iter + 1, 0)), + &(ws_diff_states(lay, dir, 0, iter, 0)), + &(weights_input(lay, dir, 0)), + &(weights_states(lay, dir, 0)), + &(bias(lay, dir, 0)), + &(ws_states(lay, dir, iter + 1, 0)), + &(ws_states(lay + 1, dir, iter, 0)), + &(ws_c_states(lay + 1, dir, iter, 0)), + &(ws_diff_states(lay + 1, dir, 0, iter, 0)), + &(ws_diff_states(lay, dir, 0, iter + 1, 0)), + &(diff_weights_layer(lay, dir, 0)), + &(diff_weights_iter(lay, dir, 0)), + &(diff_bias(lay, dir, 0)), + &(ws_gates(lay, dir, iter, 0)), + &(ws_grid(lay, dir, iter, 0)), + ws_cell_); + } + + if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter, + rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0), + rnn.weights_layer_ld, + (src_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, 0.0, + (acc_data_t *)(&(ws_diff_states( + lay, dir, rnn.n_states, 0, 0))), + rnn.states_ws_ld); + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay, dir, 1, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))), + rnn.diff_weights_layer_ld); + } + if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) { + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))), + rnn.diff_weights_iter_ld); + } + } + } +} + +//********* GRID computations strategy: utility functions **********// + +template +void _ref_rnn_common_t::copy_init_layer( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_diff_states_, const src_data_t *__restrict xt_, + const float *__restrict diff_dst_layer_) const { + + AOC ws_states( + ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto xt_d = memory_desc_wrapper(pd()->src_md(0)); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto xxt = xt_ + xt_d.blk_off(it, b); + src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0)); + src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0)); + if (rnn.exec_dir != r2l) + for (int c = 0; c < rnn.slc; c++) + ws_l2r_ptr[c] = xxt[c]; + if (rnn.exec_dir != l2r) + for (int c = 0; c < rnn.slc; c++) + ws_r2l_ptr[c] = xxt[c]; + }); +} + +template <> +void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_, + const float *diff_dst_layer_) const { + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0)); + + switch (rnn.exec_dir) { + case bi_concat: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[rnn.dic + s]; + } + }); + break; + case bi_sum: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case l2r: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case r2l: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x = diff_dst_layer_ + + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + default: assert(!"Unsupported direction"); break; + } +} + +/* For int8 configuration, input iteration states may be of types f32 or u8 + * Internally h_state is always stored in u8 and c_state is always stored in f32 + * If input states are of type u8 then h state is copied and c state is dequantized + * If input states are of type f32 then h state is quantized and c_state is copied + * */ +template +template +void _ref_rnn_common_t::copy_init_iter( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_c_states_, float *__restrict ws_diff_states_, + const input_data_t *__restrict firstit_states_, + const float *__restrict diff_dst_iter_) const { + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](input_data_t f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + } else + return (src_data_t)f; + }; + + const bool dequantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::u8; + auto maybe_deq = [&](input_data_t s) { + if (dequantize) + return (((float)s - data_shift) / data_scale); + else + return (float)s; + }; + auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1)); + if (firstit_states_) { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.sic; s++) + ws_states(lay + 1, dir, 0, b, s) = maybe_q( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 0, b, s)]); + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.sic; s++) + ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 1, b, s)]); + }); + } else { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int j = 0; j < rnn.sic; j++) { + ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0; + ws_c_states(lay + 1, dir, 0, b, j) = 0.0f; + } + }); + } +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter_) const { + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1)); + if (diff_dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + array_copy(&(ws_diff_states( + lay, dir, state, rnn.n_iter, b, 0)), + diff_dst_iter_ + + diff_dst_iter_d.blk_off( + lay, dir, state, b), + rnn.dic); + }); + } else { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int i) { + for (int j = 0; j < rnn.dic; j++) + ws_diff_states(lay, dir, state, rnn.n_iter, i, j) + = 0.0f; + }); + } +} + +template +template +void _ref_rnn_common_t::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + + auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float shift = (pd()->attr()->rnn_data_qparams_.shift_); + float scale = (pd()->attr()->rnn_data_qparams_.scale_); + + const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (dst_data_t)(((float)s - shift) / scale); + else + return (dst_data_t)s; + }; + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + if (rnn.exec_dir != r2l) { + for (int s = 0; s < rnn.dic; s++) { + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s)); + } + dir = 1; + } + if (rnn.exec_dir != l2r) { + for (int s = 0; s < rnn.dic; s++) + switch (rnn.exec_dir) { + case bi_sum: + dst_layer_[dst_layer_d.blk_off(it, b, s)] + += maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + break; + default: + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + } + } + }); +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0)); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + for (int s = 0; s < rnn.slc; s++) { + float *dst_addr = diff_src_layer_ + + diff_src_layer_d.blk_off( + (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it, + b, dir * rnn.slc + s); + float res = ws_diff_states(0, 0, rnn.n_states, it, b, s); + if (rnn.n_dir - 1) + res += ws_diff_states( + 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s); + dst_addr[0] = res; + } + }); +} + +template +template +void _ref_rnn_common_t::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); + AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::u8 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](float f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0()(qf); + } else + return (output_data_t)f; + }; + + const bool dequantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (output_data_t)(((float)s - data_shift) / data_scale); + else + return (output_data_t)s; + }; + if (dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, + [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)] + = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s)); + } + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)] + = maybe_q(ws_c_states( + lay + 1, dir, rnn.n_iter, b, s)); + } + }); + } +} + +template <> +template +void ref_rnn_bwd_f32_t::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1)); + AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + if (diff_src_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + for (int s = 0; s < rnn.sic; s++) { + diff_src_iter_[diff_src_iter_d.blk_off( + lay, dir, state, b, s)] + = ws_diff_states(lay, dir, state, 0, b, s); + } + }); + } +} + +template +rnn_bias_prepare_sig((_ref_rnn_common_t::bias_prepare)) { + /* Original set of bias provided by the user */ + AOC b( + b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + /* Array of pointers initialized in packing */ + AOC bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC scratch_bias( + scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + + if (rnn.copy_bias) { + parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic, + [&](size_t i) { scratch_bias_[i] = b_[i]; }); + } + + for (int i = 0; i < rnn.n_layer; i++) { + for (int d = 0; d < rnn.n_dir; d++) { + int offset_bias = 0; + for (int p = 0; p < rnn.n_parts_bias; p++) { + bias(i, d, p) = rnn.copy_bias + ? (float *) &scratch_bias(i, d, offset_bias) + : (float *) &b(i, d, offset_bias); + offset_bias += rnn.parts_bias[p] * rnn.dic; + } + } + } + +} + +template +rnn_bias_finalize_sig( + (_ref_rnn_common_t::bias_finalize)) { + if (rnn.dt_conf != all_f32) { + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0; + for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++) + for (int j = 0; j < rnn.n_bias * rnn.dic; j++) { + size_t off = i * rnn.n_bias * rnn.dic + j; + float weights_scale + = scale_per_oc ? weights_scales[j] : weights_scales[0]; + scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off]) + * data_shift / (weights_scale * data_scale); + } + } +} + +template +rnn_weights_assign_sig((_ref_rnn_common_t::assign_packed_weights)) { + assert(md->format_kind == format_kind::rnn_packed); + const auto packed_desc = md->format_desc.rnn_packed_desc; + AOC weights(weights_, + rnn.n_layer, rnn.n_dir, packed_desc.n_parts); + + size_t offset_packed = 0; + for (int l = 0; l < rnn.n_layer; l++) + for (int d = 0; d < rnn.n_dir; d++) { + for (int p = 0; p < packed_desc.n_parts; p++) { + weights(l, d, p) = (weights_data_t *)&w_[offset_packed]; + offset_packed + += packed_desc.part_pack_size[p] / sizeof(weights_data_t); + } + } +} + +template +rnn_weights_assign_sig( + (_ref_rnn_common_t::assign_weights)) { + assert(md->format_kind == format_kind::blocked); + const auto &blk = md->format_desc.blocking; + /* Original set of weights provided by the user */ + AOC w(w_, + rnn.n_layer, rnn.n_dir, (int)blk.strides[1]); + /* Array of pointers for each part of weights */ + AOC weights(weights_, rnn.n_layer, rnn.n_dir, n_parts); + + for (int i = 0; i < rnn.n_layer; i++) + for (int d = 0; d < rnn.n_dir; d++) { + size_t offset_weights = 0; + for (int p = 0; p < n_parts; p++) { + weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights); + offset_weights += gates_per_part[p] * blk.strides[3]; + } + } +} + +//********************* Execution function *********************// +template +void _ref_rnn_common_t::execute_( + const exec_ctx_t &ctx) const { + const rnn_conf_t &rnn = this->pd()->rnn_; + auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER); + auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER); + auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER); + auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + + auto dst_last_layer = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER)); + auto dst_last_iter = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER)); + + auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER); + auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER); + + auto w_layer = reinterpret_cast(layer_weights_n_comp); + auto w_iter = reinterpret_cast(iter_weights_n_comp); + auto w_iter_comp = reinterpret_cast( + iter_weights_n_comp + rnn.weights_iter_comp_offset); + auto w_layer_comp = reinterpret_cast( + layer_weights_n_comp + rnn.weights_layer_comp_offset); + + auto scratchpad = this->scratchpad(ctx); + + auto ptr_wei_layer + = scratchpad.template get(key_rnn_ptrs_wei_layer); + auto ptr_wei_iter + = scratchpad.template get(key_rnn_ptrs_wei_iter); + auto ptr_bias = + scratchpad.template get(key_rnn_ptrs_bia); + + // fetchihg buffers from the workspace + // if no workspace was provided we use the scratchpad + char *scratch_ptr = scratchpad.template get(key_rnn_space); + char *ws_ptr = nullptr; + if (rnn.use_workspace) + ws_ptr = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE) + : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE)); + + char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr; + acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_); + src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_); + float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_); + float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_); + float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_); + float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_); + + auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER); + auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER); + + auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER); + auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + // Fetching extra buffers from scratchpad + float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_); + + // initialize diff_states to 0 + if (aprop == prop_kind::backward) + array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float)); + + /* Pack(if using packed gemm API) or copy(if input arrays have bad leading + * dimension */ + (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias); + + (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1), + rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic, + rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter, + rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter, + ptr_bias, bias, ws_bias); + (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0), + rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc, + rnn.n_parts_weights_layer, rnn.parts_weights_layer, + rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias, + bias, ws_bias); + + (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp); + + // we first need to copy the initial states and input into ws + copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer); + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const float *)states, diff_dst_iter); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const uint8_t *)states, diff_dst_iter); + else + assert(!"unimplemented"); + + // run the execution on the grid + (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias, + ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid, + diff_weights_layer, diff_weights_iter, diff_bias); + + // Finally we copy the results to the result buffers + if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states, + ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8) + copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer, + ws_states, ws_diff_states); + else + assert(!"unimplemented"); + + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else + assert(!"unimplemented"); +}; + +/* Fix for MSVS warning C4661 */ +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise); + +template struct _ref_rnn_common_t; +template struct _ref_rnn_common_t; +template struct _ref_rnn_common_t; + +#undef AOC +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp new file mode 100644 index 0000000000..6f449a9016 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp @@ -0,0 +1,328 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_RNN_HPP +#define CPU_REF_RNN_HPP + +#include + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "../cpu_isa_traits.hpp" +#include "../gemm/os_blas.hpp" + +#include "cpu_rnn_pd.hpp" +#include "../cpu_primitive.hpp" +#include "rnn_utils.hpp" +#include "jit_uni_rnn_postgemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +float activation(float s, float alpha, float cliping, float dd); + +template +struct _ref_rnn_common_t : public cpu_primitive_t { + typedef typename prec_traits::type src_data_t; + typedef typename prec_traits::type weights_data_t; + typedef typename utils::conditional::type acc_data_t; + + using class_name = _ref_rnn_common_t; + + typedef rnn_elemwise_sig((class_name::*elemwise_f)); + typedef rnn_cell_execution_sig((class_name::*cell_execution_f)); + typedef rnn_grid_execution_sig((class_name::*grid_execution_f)); + + typedef rnn_gemm_sig((class_name::*gemm_t)); + typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t)); + typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t)); + typedef rnn_weights_assign_sig((class_name::*weights_assign_t)); + + using base_pd_t = + typename utils::conditional::type; + + struct pd_t : public base_pd_t { + using base_pd_t::base_pd_t; + + DECLARE_COMMON_PD_T("ref:any", class_name); + + status_t init() { + using namespace prop_kind; + using namespace utils; + using namespace format_tag; + using namespace rnn_utils; + const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind; + + data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type; + data_type_t weights_iter_dt + = this->desc()->weights_iter_desc.data_type; + data_type_t weights_layer_dt + = this->desc()->weights_layer_desc.data_type; + + bool ok = true + && one_of(cell_kind, alg_kind::vanilla_rnn, + alg_kind::vanilla_lstm, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset) + && IMPLICATION(aprop == prop_kind::forward, + one_of(this->desc()->prop_kind, forward_training, + forward_inference)) + && IMPLICATION(aprop == backward, + one_of(this->desc()->prop_kind, backward)) + && src_layer_dt == src_type + && everyone_is( + weights_type, weights_iter_dt, weights_layer_dt) + && this->set_default_params() == status::success + && this->with_bias(); + if (!ok) + return status::unimplemented; + + init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1), + this->weights_md(0), this->weights_md(1), this->dst_md(0)); + + if (rnn_.dt_conf == all_f32) + ok = ok && this->attr()->has_default_values(); + + // Set weights descriptors to desired format + memory_desc_t new_weights_layer_md = *this->weights_md(0); + CHECK(set_expected_desc(rnn_, new_weights_layer_md, false)); + if (this->weights_layer_md_.format_kind == format_kind::any) { + this->weights_layer_md_ = new_weights_layer_md; + } else if (this->weights_layer_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_layer_md_ != new_weights_layer_md) + return status::unimplemented; + } + + memory_desc_t new_weights_iter_md = *this->weights_md(1); + CHECK(set_expected_desc(rnn_, new_weights_iter_md, true)); + if (this->weights_iter_md_.format_kind == format_kind::any) { + this->weights_iter_md_ = new_weights_iter_md; + } else if (this->weights_iter_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_iter_md_ != new_weights_iter_md) + return status::unimplemented; + } + + CHECK(this->check_layout_consistency()); + + set_conf(rnn_, *this->desc(), this->weights_md(0), + this->weights_md(1), this->diff_weights_md(0), + this->diff_weights_md(1)); + + size_t scratchpad_sz{0}, ws_sz{0}; + get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz); + + // initialize the workspace if needed + if (rnn_.is_training) { + dims_t ws_dims = { (int)ws_sz }; + mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims, + data_type::u8, format_tag::x); + } + + init_scratchpad(scratchpad_sz); + + return status::success; + } + + rnn_utils::rnn_conf_t rnn_; + + private: + void init_scratchpad(size_t scratchpad_sz) { + using namespace memory_tracking::names; + auto scratchpad = this->scratchpad_registry().registrar(); + scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096); + + int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1; + int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts; + scratchpad.book(key_rnn_ptrs_wei_layer, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_wei_iter, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_bia, + sizeof(float *) * ptr_wei_sz); + } + }; + + _ref_rnn_common_t(const pd_t *apd) + : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) { + /// @todo set max_feature_size assuming that we limit the number of + /// iterations and layer to one if slc != dic and sic != dic + /// respectively + + bias_preparation_func = &class_name::bias_prepare; + bias_finalization_func = &class_name::bias_finalize; + + auto set_gemm_funcs + = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) { + if (packed_gemm) { + g = &class_name::packed_gemm; + a = &class_name::assign_packed_weights; + } else { + g = &class_name::gemm; + a = &class_name::assign_weights; + } + }; + set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func, + weights_iter_assign_func); + + set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func, + weights_layer_assign_func); + + switch (pd()->cell_kind()) { + case alg_kind::vanilla_lstm: + cell_func = &class_name::cell_execution; + if (aprop == prop_kind::forward) { + if (mayiuse(avx512_core)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + else if (mayiuse(avx2)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + else if (mayiuse(sse42)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( + pd()->rnn_, pd()->attr()); + assert(rnn_postgemm_ != nullptr); + rnn_postgemm_->init(); + } + elemwise_func = &class_name::lstm_elemwise; + break; + case alg_kind::vanilla_rnn: // @todo switch on cell kind + cell_func = &class_name::cell_execution; + elemwise_func = &class_name::rnn_elemwise; + switch (pd()->activation_kind()) { + case alg_kind::eltwise_relu: + activation_func = &activation; + break; + case alg_kind::eltwise_tanh: + activation_func = &activation; + break; + case alg_kind::eltwise_logistic: + activation_func = &activation; + break; + default: break; + } + break; + case alg_kind::vanilla_gru: + cell_func = &class_name::cell_execution_gru; + break; + case alg_kind::gru_linear_before_reset: + cell_func = &class_name::cell_execution_gru_lbr; + elemwise_func = &class_name::gru_lbr_elemwise; + break; + default: break; + } + + grid_computation = &class_name::linear_execution; + + size_t scratchpad_size, workspace_size; + rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_, + ws_c_states_offset_, ws_diff_states_offset_, + ws_grid_comp_offset_, ws_cell_comp_offset_, + ws_bias_offset_, scratchpad_size, workspace_size); + } + + ~_ref_rnn_common_t() {} + + // typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_(ctx); + return status::success; + } + +private: + void execute_(const exec_ctx_t &ctx) const; + rnn_grid_execution_sig(linear_execution); + rnn_cell_execution_sig(cell_execution); + rnn_cell_execution_sig(cell_execution_gru); + rnn_cell_execution_sig(cell_execution_gru_lbr); + rnn_elemwise_sig(rnn_elemwise); + rnn_elemwise_sig(lstm_elemwise); + rnn_elemwise_sig(gru_lbr_elemwise); + rnn_gemm_sig(gemm); + rnn_gemm_sig(packed_gemm); + rnn_bias_prepare_sig(bias_prepare); + rnn_bias_finalize_sig(bias_finalize); + rnn_weights_assign_sig(assign_weights); + rnn_weights_assign_sig(assign_packed_weights); + + float (*activation_func)(float dd, float s, float alpha, float cliping); + + void copy_init_layer(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, + const src_data_t *xt_, const float *diff_dst_layer) const; + + template + void copy_init_iter(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter) const; + + template + void copy_res_layer(const rnn_utils::rnn_conf_t &rnn, + dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const; + + template + void copy_res_iter(const rnn_utils::rnn_conf_t &rnn, + output_data_t *dst_iter_, float *diff_src_iter, + const src_data_t *ws_states_, float *ws_c_states, + const float *ws_diff_states_) const; + + void gates_reduction(const rnn_utils::rnn_conf_t &rnn, + const acc_data_t *ws_gates_, float *diff_bias_) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + size_t ws_gates_offset_; + size_t ws_states_offset_; + size_t ws_c_states_offset_; + size_t ws_bias_offset_; + size_t ws_diff_states_offset_; + size_t ws_grid_comp_offset_; + size_t ws_cell_comp_offset_; + jit_uni_rnn_postgemm_kernel *rnn_postgemm_; + + grid_execution_f grid_computation; + cell_execution_f cell_func; + + bias_prepare_t bias_preparation_func; + bias_finalize_t bias_finalization_func; + weights_assign_t weights_layer_assign_func; + weights_assign_t weights_iter_assign_func; + + gemm_t gemm_layer_func; + gemm_t gemm_iter_func; + elemwise_f elemwise_func; +}; + +using ref_rnn_fwd_f32_t = _ref_rnn_common_t; +using ref_rnn_bwd_f32_t = _ref_rnn_common_t; +using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t; +} +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp new file mode 100644 index 0000000000..597c63e3f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp @@ -0,0 +1,380 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_RNN_REORDERS_HPP +#define CPU_RNN_REORDERS_HPP + +#include + +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "simple_q10n.hpp" +#include "cpu_reorder_pd.hpp" +#include "../gemm/os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct rnn_data_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc) + && od == id; + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign(*reorder_pd, _pd); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + + rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const size_t nelems = input_d.nelems(); + const float scale = pd()->attr()->rnn_data_qparams_.scale_; + const float shift = pd()->attr()->rnn_data_qparams_.shift_; + + parallel_nd(nelems, [&](size_t i) { + float in = (float)input[input_d.off_l(i)] * scale + shift; + output[output_d.off_l(i)] = qz_a1b0()(in); + }); + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template +struct rnn_weights_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && od.format_kind() == format_kind::rnn_packed + && od.rnn_packed_desc().format == mkldnn_ldigo_p + && od.rnn_packed_desc().n_parts == 1 + && attr != nullptr; + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + _pd->itag_ = itag; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + format_tag_t itag_; + + private: + void init_scratchpad() { + const memory_desc_wrapper id(src_md()); + const size_t nelems = id.nelems(); + const auto &dims = id.dims(); + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + size_t quantization_size = sizeof(int8_t) * nelems; + size_t reduction_size = itag_ == ldigo + ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0] + * dims[1] * dims[3] * dims[4] + : 0; + scratchpad.book( + key_reorder_rnn_weights_quantization, quantization_size); + scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + + /* Quantize input & compute compensation */ + auto quantized = (int8_t * __restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_rnn_weights_quantization); + auto reduction = (int32_t * __restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_rnn_weights_reduction); + float *comp = reinterpret_cast( + output + output_d.rnn_packed_desc().offset_compensation); + const float *scales = pd()->attr()->rnn_weights_qparams_.scales_; + const int mask = pd()->attr()->rnn_weights_qparams_.mask_; + + if (is_igo) { + int nthr = mkldnn_get_max_threads(); + int LD_nthr = nstl::min(L * D, nthr); + int I_nthr = nstl::min(I, nthr / LD_nthr); + parallel(nthr, [&](const int ithr, const int nthr) { + int LD_ithr = -1, LD_s = -1, LD_e = -1; + int I_ithr = -1, I_s = -1, I_e = -1; + if (ithr < LD_nthr * I_nthr) { + LD_ithr = ithr % LD_nthr; + I_ithr = ithr / LD_nthr; + balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e); + balance211(I, I_nthr, I_ithr, I_s, I_e); + } + int32_t *comp_ithr = reduction + I_ithr * L * D * G * O; + for (int ld = LD_s; ld < LD_e; ld++) { + for (int go = 0; go < G * O; go++) + comp_ithr[ld * G * O + go] = 0; + for (int i = I_s; i < I_e; i++) { + PRAGMA_OMP_SIMD() + for (int go = 0; go < G * O; go++) { + const float s = scales[(mask == 0) ? 0 : go]; + int8_t q = qz_b0()( + input[ld * I * G * O + i * G * O + go], s); + quantized[ld * I * G * O + i * G * O + go] + = (int32_t)q; + comp_ithr[ld * G * O + go] += (int32_t)q; + } + } + } + }); + parallel_nd(L * D * G * O, + [&](int s) { comp[s] = saturate(reduction[s]); }); + for (int i = 1; i < I_nthr; i++) { + parallel_nd(L * D * G * O, [&](int s) { + comp[s] += saturate( + reduction[i * L * D * G * O + s]); + }); + } + } else { + parallel_nd(L * D, G * O, [&](int ld, int go) { + int32_t compensation = 0; + const float s = scales[(mask == 0) ? 0 : go]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < I; i++) { + int8_t q = qz_b0()( + input[ld * G * O * I + go * I + i], s); + compensation += (int32_t)q; + quantized[ld * G * O * I + go * I + i] = q; + } + comp[ld * G * O + go] = saturate(compensation); + }); + } + + /* Pack */ + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + int n_parts = output_d.rnn_packed_desc().n_parts; + const size_t *size_packed_cell + = output_d.rnn_packed_desc().part_pack_size; + const int *parts = output_d.rnn_packed_desc().parts; + const int n = output_d.rnn_packed_desc().n; + char *to_pack = output; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = parts[p] * O; + int k_p = I; + cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix, + is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p, + &quantized[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, g, 0, 0)], + is_igo ? G * O : I, to_pack); + to_pack += size_packed_cell[p]; + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <> +struct rnn_weights_reorder_t + : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == data_type::f32 + && od.data_type() == data_type::f32 + && od.format_kind() == format_kind::rnn_packed + && utils::one_of(od.rnn_packed_desc().format, + mkldnn_ldigo_p, mkldnn_ldgoi_p) + && attr->has_default_values(); + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->itag_ = itag; + return safe_ptr_assign(*reorder_pd, _pd); + } + + format_tag_t itag_; + }; + +private: + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc(); + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + /* Pack */ + bool cross_case = false + || (pd()->itag_ == format_tag::ldigo + && rnn_pdata.format == mkldnn_ldgoi_p) + || (pd()->itag_ == format_tag::ldgoi + && rnn_pdata.format == mkldnn_ldigo_p); + auto trans = cross_case ? CblasTrans : CblasNoTrans; + int n_parts = rnn_pdata.n_parts; + const size_t *size_packed_cell = rnn_pdata.part_pack_size; + const int *parts = rnn_pdata.parts; + const int n = rnn_pdata.n; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = is_igo ? parts[p] * O : I; + int k_p = is_igo ? I : parts[p] * O; + int ld = is_igo ? G * O : I; + cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n, + k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, 0, g, 0)], + ld, output); + output += size_packed_cell[p] / sizeof(float); + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp new file mode 100644 index 0000000000..1d60415cbc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp @@ -0,0 +1,426 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "rnn_utils.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace rnn_utils; +using namespace format_tag; +using namespace rnn_packed_format; +using namespace data_type; + +bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1 + && str[3] == dims[4] && str[1] == str[2] * dims[2] + && str[0] == str[1] * dims[1]; +}; + +bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1 + && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3] + && str[0] == str[1] * dims[1]; +}; + +void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d) { + rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + rnn.is_training = utils::one_of( + rd.prop_kind, prop_kind::forward_training, prop_kind::backward); + rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset; + + switch (rd.direction) { + case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break; + case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break; + case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break; + case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break; + default: break; + } + + if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(), + weights_layer_d.data_type())) + rnn.dt_conf = all_f32; + else if (dst_layer_d.data_type() == u8) { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8u8; + else + rnn.dt_conf = f32u8f32u8; + } else { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8f32; + else + rnn.dt_conf = f32u8f32f32; + } + + rnn.n_layer = weights_layer_d.dims()[0]; + rnn.n_iter = src_layer_d.dims()[0]; + rnn.n_dir = weights_layer_d.dims()[1]; + rnn.n_gates = weights_layer_d.dims()[3]; + rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc); + rnn.n_bias = rnn.n_gates + rnn.is_lbr; + rnn.mb = src_layer_d.dims()[1]; + rnn.sic = weights_iter_d.dims()[2]; + rnn.slc = weights_layer_d.dims()[2]; + rnn.dic = weights_layer_d.dims()[4]; + rnn.dlc = dst_layer_d.dims()[2]; + + rnn.gates_ld = rnn.dic * rnn.n_gates; + rnn.gates_nld = rnn.mb; + rnn.states_nld = rnn.mb; + + /* Set the correct number of weights parts */ + bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru; + rnn.n_parts_weights_layer = 1; + rnn.parts_weights_layer[0] = rnn.n_gates; + rnn.parts_weights_layer[1] = 0; + + rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1; + rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates; + rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0; + + rnn.n_parts_bias = 1; + rnn.parts_bias[0] = rnn.n_bias; + rnn.parts_bias[1] = 0; + + /* Decide wich gemm implementation to use: packed/nonpacked jit/cblas + * and if to mergre gemm across iterations */ + bool is_int8 = rnn.dt_conf != all_f32; + rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) + || is_int8; + bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset); + rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8; + bool is_inference = !rnn.is_training; + + rnn.use_jit_gemm = !mayiuse(avx512_mic) + && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100)) + || (rnn.is_training && rnn.dic < 500)); + + /* Decide to copy bias */ + rnn.copy_bias = rnn.dt_conf != all_f32; + +#if USE_MKL_PACKED_GEMM + rnn.use_layer_packed_gemm + = (weights_layer_d.format_kind() == format_kind::any + && rnn.slc > 760 && rnn.dic > 760 && is_inference) + || is_int8; // packed gemm is the only supported option for int8 + rnn.use_iter_packed_gemm + = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760 + && rnn.dic > 760 && is_inference) + || is_int8; +#else + rnn.use_layer_packed_gemm = false; + rnn.use_iter_packed_gemm = false; +#endif + + /* Set packed gemm sizes */ + if (rnn.use_layer_packed_gemm) { + rnn.weights_layer_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_layer; p++) { + int m_p = rnn.is_fwd + ? (rnn.parts_weights_layer[p] * rnn.dic) + : rnn.slc; + int k_p = rnn.is_fwd + ? rnn.slc + : (rnn.parts_weights_layer[p] * rnn.dic); + int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_layer_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_layer_pack_size[p] = 0; +#endif + rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_layer_pack_size[p]; + } + rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size; + rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float); + } + + if (rnn.use_iter_packed_gemm) { + rnn.weights_iter_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_iter; p++) { + int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) : + rnn.sic; + int k_p = rnn.is_fwd ? rnn.sic : + (rnn.parts_weights_iter[p] * rnn.dic); + int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_iter_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_iter_pack_size[p] = 0; +#endif + rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_iter_pack_size[p]; + } + rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size; + rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float); + } + +} + +void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d) { + + /* Set leading dimensions for input weights arrays depending on input format + */ + rnn.weights_layer_is_packed + = weights_layer_d.format_kind() == format_kind::rnn_packed; + rnn.weights_iter_is_packed + = weights_iter_d.format_kind() == format_kind::rnn_packed; + + auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) { + ld = 0; nld = 0; + if (md.is_blocking_desc()) { + if (is_ldigo(md)) { + ld = (int)md.blocking_desc().strides[2]; + nld = md.dims()[2]; + } else if (is_ldgoi(md)) { + ld = (int)md.blocking_desc().strides[4]; + nld = md.dims()[3] * md.dims()[4]; + } else + assert(!"unsupported weights format"); + } + }; + set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld); + set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld); + if (!rnn.is_fwd) { + set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld, + rnn.diff_weights_layer_nld); + set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld, + rnn.diff_weights_iter_nld); + } + + int sizeof_states_dt + = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t); + rnn.states_ws_ld + = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)), + sizeof_states_dt); + rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float)); + + /* Set workspace sizes to store: + * states to copmute a pass + * diff states to copmute bwd pass (training only) + * intermediate results from the gates + */ + rnn.use_workspace = rnn.is_training; + rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir + * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt; + bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm; + rnn.ws_c_states_size = is_lstm + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb + * rnn.states_ws_ld * sizeof(float) + : 0; + rnn.ws_diff_states_size = rnn.is_training + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) + * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld + * sizeof(float) + : (size_t)0; + rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb + * rnn.gates_ws_ld * sizeof(float); + + /* set other sizes */ + rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float); + rnn.ws_cell_comp_size + = rnn.is_lbr || rnn.dt_conf != all_f32 + ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float) + : 0; + rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer + * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float); + rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic + * sizeof(float); +} + +int rnn_utils::get_good_ld(int dim, int sizeof_dt) { + // we want matrices leading dimentions to be 64-byte aligned, + // and not divisible by 256 to avoid 4K aliasing effects + int ld = rnd_up(dim, 64 / sizeof_dt); + return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld; +} + +void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_states_offset, size_t &ws_c_states_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size) { + + const size_t page_size = 4096; // 2097152; + size_t current_offset; + /* Mandatory workspaces: go to workspace if use_workspace, scratchpad + * otherwise */ + current_offset = 0; // assumes the workspace base pointer is page aligned + ws_gates_offset = current_offset; + current_offset += rnn.ws_gates_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_states_offset = current_offset; + current_offset += rnn.ws_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_c_states_offset = current_offset; + current_offset += rnn.ws_c_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_diff_states_offset = current_offset; + current_offset += rnn.ws_diff_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_grid_comp_offset = current_offset; + current_offset += rnn.ws_grid_comp_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_cell_comp_offset = current_offset; + current_offset += rnn.ws_cell_comp_size; + + workspace_size = rnn.use_workspace ? current_offset : 0; + + /* Optional scratchpads */ + // Assumes the scratchpad base pointer is page aligned. + // If use_workspace, the following goes to scratchpad alone, + // otherwise, all goes to scratchpad and continue incrementing offset + current_offset = rnn.use_workspace ? 0 : current_offset; + + if (rnn.copy_bias) { + current_offset = utils::rnd_up(current_offset, page_size); + ws_bias_offset = current_offset; + current_offset += rnn.ws_bias_size; + } + + scratchpad_size = current_offset; +} + +void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size) { + size_t ws_gates_offset, ws_states_offset, ws_c_states_offset, + ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset; + set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_diff_states_offset, + ws_c_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset, scratchpad_size, workspace_size); +} + +status_t rnn_utils::set_good_strides( + memory_desc_t &weights_md, format_tag_t tag) { + auto &strides = weights_md.format_desc.blocking.strides; + auto dims = weights_md.dims; + + if (tag == ldigo) { + strides[2] = rnn_utils::get_good_ld((int)strides[2], + (int)types::data_type_size(weights_md.data_type)); + strides[1] = dims[2] * strides[2]; + strides[0] = dims[1] * strides[1]; + } else if (tag == ldgoi) { + strides[4] = rnn_utils::get_good_ld((int)strides[4], + (int)types::data_type_size(weights_md.data_type)); + strides[3] = dims[4] * strides[4]; + strides[1] = dims[3] * strides[3]; + strides[0] = dims[1] * strides[1]; + } else + return status::unimplemented; + + return status::success; +} + +status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn, + memory_desc_t &weights_md, bool is_iter) { + using namespace format_tag; + bool use_packed_gemm = is_iter + ? rnn.use_iter_packed_gemm + : rnn.use_layer_packed_gemm; + if (use_packed_gemm) { + weights_md.format_kind = format_kind::rnn_packed; + rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc; + rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p; + if (is_iter) { + rnn_pdata.n = rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_iter; + array_copy(rnn_pdata.parts, rnn.parts_weights_iter, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset; + rnn_pdata.size = rnn.weights_iter_pack_size; + } else { + rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_layer; + array_copy(rnn_pdata.parts, rnn.parts_weights_layer, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset; + rnn_pdata.size = rnn.weights_layer_pack_size; + } + } else { + CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + // Adjust strides for good leading dimension in GEMM + CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + } + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp new file mode 100644 index 0000000000..99eb787a64 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_UTILS_HPP +#define RNN_UTILS_HPP + +#include "mkldnn.h" + +#include "cpu_rnn_pd.hpp" + + +#define rnn_elemwise_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \ + src_data_t *states_t_l_, float *c_states_t_l_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_l_, float *diff_states_t_lp1_, \ + float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \ + float *ws_cell_) const + +#define rnn_cell_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \ + float *c_states_t_l_, float *diff_states_t_l_, \ + weights_data_t **w_layer_, weights_data_t **w_iter_, \ + float **bias_, src_data_t *states_t_lm1_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_lp1_, float *diff_states_tp1_l_, \ + float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \ + acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const + +#define rnn_grid_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \ + weights_data_t **weights_states_, float **bias_, \ + src_data_t *ws_states_, float *ws_c_states_, \ + float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \ + float *ws_grid_, float *diff_weights_layer_, \ + float *diff_weights_iter_, float *diff_bias_) const + +#define rnn_gemm_sig(f) \ + void f(const char transA, const char transB, int m, int n, int k, \ + const float alpha, const weights_data_t *a_, const int ldA, \ + const src_data_t *b_, const int ldB, const float beta, \ + acc_data_t *c_, const int ldC) const + +#define rnn_bias_prepare_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \ + float *scratch_bias_) const + +#define rnn_bias_finalize_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \ + const float *w_iter_comp, const float *w_layer_comp) const + +#define rnn_weights_assign_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \ + int ld, int OC_size, int IC_size, const int n_parts, \ + const int *gates_per_part, const size_t *part_weights_pack_size, \ + weights_data_t **weights_, const weights_data_t *w_, \ + float **bias_, const float *b_, float *scratch_bias_) const + + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace rnn_utils { + +using namespace mkldnn::impl::utils; + +enum execution_direction_t { + l2r, + r2l, + bi_concat, + bi_sum, +}; + +enum data_type_conf_t { + all_f32, + u8u8u8f32, + f32u8f32f32, + u8u8u8u8, + f32u8f32u8 +}; + +struct rnn_conf_t { + execution_direction_t exec_dir; + data_type_conf_t dt_conf; + int n_layer, n_iter, n_dir, n_gates, n_states; + int mb; + int slc, sic, dic, dlc; + int gates_ld, gates_nld, gates_ws_ld; + int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS]; + int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS]; + int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS], + part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + bool weights_layer_is_packed, weights_iter_is_packed; + /* Size of packed data in bytes */ + size_t weights_layer_comp_offset, weights_layer_pack_size, + weights_iter_comp_offset, weights_iter_pack_size; + + bool copy_bias; + int weights_layer_ld, weights_layer_nld; + int diff_weights_layer_ld, diff_weights_layer_nld; + int weights_iter_ld, weights_iter_nld; + int diff_weights_iter_ld, diff_weights_iter_nld; + int states_nld, states_ws_ld; + int weights_iter_compensation_size, weights_layer_compensation_size; + bool is_fwd, is_training, is_lbr; + bool use_workspace; + + /* Size of workspace for each tensor in bytes */ + size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size, + ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size; + bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm, + use_iter_packed_gemm; +}; + +bool is_ldigo(const memory_desc_wrapper &md); +bool is_ldgoi(const memory_desc_wrapper &md); + +int get_good_ld(int dim, int sizeof_dt); + +void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d); + +void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d); + +void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_h_state_offset, size_t &ws_c_state_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size); + +void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size); +status_t set_expected_desc( + rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter); +status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag); + +template +struct ws_gates_aoc { + ws_gates_aoc(const rnn_conf_t &rnn, T *data) + : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {} + T &operator()(int batch, int gate, int dic) { + return gates_(batch, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator gates_; + int DIC_; +}; +using ws_gates_aoc_t = ws_gates_aoc; +using ws_gates_aoc_s32_t = ws_gates_aoc; + +struct bias_aoc_t { + bias_aoc_t(const rnn_conf_t &rnn, const float *data) + : bias_(data, rnn.n_bias, rnn.dic) {} + const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator bias_; +}; + +template +struct ws_states_aoc { + ws_states_aoc(const rnn_conf_t &rnn, T *data) + : state_(data, rnn.states_nld, rnn.states_ws_ld) {} + T &operator()(int batch, int dic) { return state_(batch, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator state_; +}; +using ws_states_aoc_t = ws_states_aoc; +using ws_states_aoc_u8_t = ws_states_aoc; + +struct ws_diff_states_aoc_t { + ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld, + rnn.states_ws_ld) {} + float &operator()(int state_n, int batch, int dic) { + return diff_states_(state_n, 0, batch, dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator diff_states_; +}; + +struct ws_diff_w_iter_aoc_t { + ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_weights_iter_( + data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld) + , DIC_(rnn.dic) {} + float &operator()(int sic, int gate, int dic) { + return diff_weights_iter_(sic, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator diff_weights_iter_; + int DIC_; +}; +} +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp new file mode 100644 index 0000000000..0420f87aa5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp @@ -0,0 +1,126 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" + +#include "simple_concat.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +template +status_t simple_concat_t::execute(const exec_ctx_t &ctx) const { + auto scratchpad = this->scratchpad(ctx); + auto iptrs = scratchpad.template get(key_concat_iptrs); + auto optrs = scratchpad.template get(key_concat_optrs); + auto nelems_to_copy = scratchpad.template get(key_concat_nelems); + auto is = scratchpad.template get(key_concat_istrides); + + const int num_arrs = pd()->n_inputs(); + const int *perm = pd()->perm_, *iperm = pd()->iperm_; + const int concat_dim = pd()->concat_dim(); + auto o_base_ptr = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + for (int a = 0; a < num_arrs; ++a) { + const memory_desc_wrapper i_d(pd()->src_md(a)); + const memory_desc_wrapper o_d(pd()->src_image_md(a)); + + iptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) + + i_d.blk_off(0); + optrs[a] = o_base_ptr + o_d.blk_off(0); + nelems_to_copy[a] = pd()->nelems_to_concat(i_d); + for (int i = 0; i < MKLDNN_MAX_NDIMS; i++) { + if (i < perm[concat_dim]) + is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]); + else + is[a][i] = 0; + } + } + + const memory_desc_wrapper o_d(pd()->src_image_md(0)); + + strides_t os = { 0 }; + for (int i = 0; i < perm[concat_dim]; i++) + os[i] = o_d.blocking_desc().strides[iperm[i]]; + + dims_t phys_dims; + for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++) + phys_dims[i] = (i < (size_t)perm[concat_dim]) + ? o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]] : 1; + + if (perm[concat_dim] == 0) { + for (int a = 0; a < num_arrs; ++a) { + const data_t *i = &iptrs[a][0]; + data_t *o = &optrs[a][0]; + parallel_nd((ptrdiff_t)nelems_to_copy[a], + [&](ptrdiff_t e) { o[e] = i[e]; }); + } + } else { + parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], + phys_dims[4], num_arrs, + [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, int a) { + // XXX: this code may access uninitialized values in is[*][0-4] -- + // that's why we have to set them to zero although this is + // probably benign + size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2 + + is[a][3] * n3 + is[a][4] * n4; + size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2 + + os[3] * n3 + os[4] * n4; + const data_t *i = &iptrs[a][in_off]; + data_t *o = &optrs[a][out_off]; +#if defined(__GNUC__) && !defined(__INTEL_COMPILER) + // The code below performs data copying: o[e] = i[e] + // and uses a workaround to make GNU compilers optimize it + uint8_t *ptro = reinterpret_cast(o); + const uint8_t *ptri = reinterpret_cast(i); + const dim_t main_part = + nelems_to_copy[a] * sizeof(data_t) / sizeof(uint32_t); + const dim_t tail_part = + nelems_to_copy[a] % sizeof(data_t) / sizeof(uint32_t); + + PRAGMA_OMP_SIMD() + for (dim_t e = 0; e < main_part; ++e) { + *(reinterpret_cast(ptro)) + = *(reinterpret_cast(ptri)); + ptro += sizeof(uint32_t); + ptri += sizeof(uint32_t); + } + for (dim_t e = 0; e < tail_part; ++e) { + *ptro = *ptri; + ++ptro; + ++ptri; + } +#else + PRAGMA_OMP_SIMD() + for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e]; +#endif + }); + } + + return status::success; +} + +template struct simple_concat_t; +template struct simple_concat_t; +template struct simple_concat_t; +template struct simple_concat_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp new file mode 100644 index 0000000000..5177275452 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp @@ -0,0 +1,155 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_CONCAT_HPP +#define SIMPLE_CONCAT_HPP + +#include "memory_tracking.hpp" + +#include "cpu_concat_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct simple_concat_t: public cpu_primitive_t { + struct pd_t: public cpu_concat_pd_t { + using cpu_concat_pd_t::cpu_concat_pd_t; + + pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { + int ndims = rhs.dst_md_.ndims; + utils::array_copy(perm_, rhs.perm_, ndims); + utils::array_copy(iperm_, rhs.iperm_, ndims); + utils::array_copy(blocks_, rhs.blocks_, ndims); + } + + DECLARE_CONCAT_PD_T("simple:any", simple_concat_t); + + status_t init() { + const memory_desc_wrapper dst_d(dst_md()); + bool ok = true + && cpu_concat_pd_t::init() == status::success + && dst_d.ndims() <= 6; + if (!ok) return status::unimplemented; + + for (size_t i = 0; i < src_mds_.size(); ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + const memory_desc_wrapper o_d(&src_image_mds_[i]); + + const int ignore_strides = 0; + + ok = ok + && utils::everyone_is(data_type, i_d.data_type(), + o_d.data_type()) + && utils::everyone_is(format_kind::blocked, + i_d.format_kind(), o_d.format_kind()) + && types::blocking_desc_is_equal(i_d.blocking_desc(), + o_d.blocking_desc(), ignore_strides) + && types::blocking_desc_is_equal(i_d.blocking_desc(), + dst_d.blocking_desc(), ignore_strides) + && !i_d.is_additional_buffer(); + if (!ok) return status::unimplemented; + } + + dst_d.compute_blocks(blocks_); + format_perm(); + + // start dim is the first dimension after which the concatenation + // would happen contiguously + const int start_dim = perm_[concat_dim()]; + + // check that contiguous part is indeed contiguous (i.e. dense) + if (nelems_to_concat(dst_d) != + dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()] + * dst_d.blocking_desc().strides[concat_dim()]) + return status::unimplemented; + + // check that all inputs have the same strides for the + // contiguous part [concat_dim .. ndims] for the *major* dims. + // the block part is already checked above + for (size_t i = 0; i < src_mds_.size(); ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + for (int d = start_dim; d < dst_d.ndims(); ++d) { + if (dst_d.blocking_desc().strides[iperm_[d]] + != i_d.blocking_desc().strides[iperm_[d]]) + return status::unimplemented; + } + } + + init_scratchpad(); + + return status::success; + } + + int perm_[MKLDNN_MAX_NDIMS]; + int iperm_[MKLDNN_MAX_NDIMS]; + dims_t blocks_; + + dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { + const int ndims = data_d.ndims(); + + dim_t nelems = 1; + for (int i = perm_[concat_dim()]; i < ndims; i++) + nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]]; + for (int i = 0; i < ndims; i++) + nelems *= blocks_[i]; + + return nelems; + } + + private: + void format_perm() { + const memory_desc_wrapper dst_d(dst_md()); + const int ndims = dst_d.ndims(); + + strides_t strides; + utils::array_copy(strides, dst_d.blocking_desc().strides, ndims); + for (int i = 0; i < ndims; i++) iperm_[i] = i; + + utils::simultaneous_sort(strides, iperm_, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i; + } + + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs()); + scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs()); + scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs()); + scratchpad.book(key_concat_istrides, + sizeof(strides_t) * n_inputs()); + } + }; + + simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override; + + typedef typename prec_traits::type data_t; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp new file mode 100644 index 0000000000..e6c3b8d7af --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SIMPLE_Q10N_HPP +#define CPU_SIMPLE_Q10N_HPP + +#include + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::math; + +template +inline out_t round_and_saturate(float f) +{ return math::saturate(out_round(f)); } + +/* Quantization with alpha == 1 and beta == 0 */ +template +struct qz_a1b0 { + out_t operator()(in_t in) + { return round_and_saturate((float)in); } +}; + +template +struct qz_a1b0::value + && !is_subset::value + >::type> { + out_t operator()(in_t in) { return math::saturate(in); } +}; + +template +struct qz_a1b0::value>::type> { + out_t operator()(in_t in) { return (out_t)in; } +}; + +/* Quantization with alpha == 1 */ +template struct qz_a1 { + out_t operator()(in_t in, out_t out, float beta) + { return round_and_saturate((float)in + beta * out); } +}; + +template struct qz_a1 { + float operator()(in_t in, float out, float beta) + { return (float)in + beta * out; } +}; + +/* Quantization with beta == 0 */ +template struct qz_b0 { + out_t operator()(in_t in, float alpha) + { return round_and_saturate(alpha * in); } +}; + +template struct qz_b0 { + float operator()(in_t in, float alpha) { return alpha * in; } +}; + +/* Quantization */ +template struct qz { + out_t operator()(in_t in, out_t out, float alpha, float beta) { + return round_and_saturate( + alpha * in + (beta ? beta * out : 0)); + } +}; + +template struct qz { + float operator()(in_t in, float out, float alpha, float beta) + { return alpha * in + (beta ? beta * out : 0); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp new file mode 100644 index 0000000000..ff845f5bd3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp @@ -0,0 +1,1022 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SIMPLE_REORDER_HPP +#define CPU_SIMPLE_REORDER_HPP + +#include + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "tag_traits.hpp" +#include "cpu_reorder_pd.hpp" +#include "cpu_primitive.hpp" + +#include "simple_q10n.hpp" +#include "cpu_isa_traits.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::data_type; + +using bd = block_dim_t; +using ib = inner_blk_t; + +using namespace mkldnn::impl::utils; +using math::saturate; + +template +using data_t = typename prec_traits::type; + +template +using _qz_a1b0 = qz_a1b0, data_t>; + +template +using _qz = qz, data_t>; + +namespace fmt_order { + const bool keep = true; + const bool reverse = false; + const bool any = keep; +} + +namespace spec { +struct direct_copy {}; +struct direct_copy_except_dim_0 {}; +struct reference {}; +struct conv_s8s8 {}; +} + +#define SIMPLE_REORDER_TEMPL_DECL \ + impl::data_type_t type_i, impl::format_tag_t tag_i, \ + impl::data_type_t type_o, impl::format_tag_t tag_o, bool order_keep +#define SIMPLE_REORDER_TEMPL_CALL \ + type_i, tag_i, type_o, tag_o, order_keep + +#define DECLARE_COMMON_PARAMS() \ + const memory_desc_wrapper &input_d = pd->src_md(); \ + const memory_desc_wrapper &output_d = pd->dst_md(); \ + const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \ + const float beta = pd->beta(); MAYBE_UNUSED(beta); + +/* specific reorders: common template */ +template +struct simple_reorder_impl {}; + +namespace { +inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i, + impl::format_tag_t tag_o, const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d) { + return input_d.matches_tag(order_keep ? tag_i : tag_o) + && output_d.matches_tag(order_keep ? tag_o : tag_i); +} +inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) { + if (many_scales_support) + return true; + return IMPLICATION(attr, attr->output_scales_.mask_ == 0); +} +} + +/* specific reorders: implementation */ +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const int oc = (input_d.dims()[tag_o == hwigo + 0]); + const int g = (tag_o == hwigo) ? (input_d.dims()[0]) : 1; + + return output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = tag_o == hwigo; + + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int IC = dims[w_groups + 1]; + const int H = dims[w_groups + 2]; + const int W = dims[w_groups + 3]; + + const float *scales = pd->attr()->output_scales_.scales_; + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W; + int32_t *cp = reinterpret_cast(output + offset); + + parallel_nd(G, OC, [&](int g, int oc) { + cp[g * OC + oc] = 0; + for (int ic = 0; ic < IC; ic++) + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) { + auto i = input[input_d.blk_off(g, oc, ic, h, w)]; + auto &o = output[output_d.blk_off(g, oc, ic, h, w)]; + const float s = scales[(D_mask == 1) ? 0 : g * OC + oc]; + + o = qz_b0, data_t>()( + i, s * adj_scale); + cp[g * OC + oc] -= (int32_t)o; + } + cp [g * OC + oc] *= 128; + }); + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const bool w_groups = !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); + const int oc = (input_d.dims()[w_groups ? 1 : 0]); + const int g = w_groups ? input_d.dims()[0] : 1; + + return input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = + !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); + constexpr int is_1d = + utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i); + constexpr int blksize = tag_traits::inner_blks == ib::_4b4c + ? 4 + : tag_traits::inner_blks == ib::_2c8b4c + ? 8 + : 16; + + const auto &_g_oihw_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int NB_OC = pdims[w_groups + 0] / blksize; + const int IC = dims[w_groups + 1]; + const int NB_IC = pdims[w_groups + 1] / blksize; + const int H = is_1d ? 1 : dims[w_groups + 2]; + const int W = dims[w_groups + 3 - is_1d]; + + const float *scales = pd->attr()->output_scales_.scales_; + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + auto ker = [&](const data_t *inp, data_t *out, + int32_t *c, const float *s, const int oc_block, const int ic_block) { +# define index AB_or_BC_blk_off::inner_blks> + + for (int ic = 0; ic < ic_block; ++ic) { + for (int oc = 0; oc < oc_block; ++oc) { + const auto _g_oihw_off = + oc * _g_oihw_d.blocking_desc().strides[w_groups + 0] + + ic * _g_oihw_d.blocking_desc().strides[w_groups + 1]; + out[index(oc, ic)] + = qz_b0, data_t>()( + inp[_g_oihw_off], s[oc] * adj_scale); + c[oc] -= (128 * (int32_t)(out[index(oc, ic)])); + } + } +# undef index + }; + + constexpr int i_mult = blksize; + constexpr int o_mult = 1; + + size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W; + int32_t *cp = reinterpret_cast(output + offset); + parallel_nd(G * NB_OC * blksize, [&](int i) { + cp[i] = 0; + }); + +# define wei_blk_off(md, g, o, i, h, w) \ + (is_1d ? (md).blk_off(g, o, i, w) \ + : (md).blk_off(g, o, i, h, w)) + + parallel_nd(G, NB_OC, [&](int g, int O) { + for (int I = 0; I < NB_IC; I++) + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) { + auto i = &input[wei_blk_off( + input_d, g, i_mult * O, i_mult * I, h, w)]; + auto o = &output[wei_blk_off( + output_d, g, o_mult * O, o_mult * I, h, w)]; + const int oc_block = nstl::min(blksize, OC - O * blksize); + const int ic_block = nstl::min(blksize, IC - I * blksize); + + int _offset = (g * NB_OC + O) * blksize; + ker(i, o, (order_keep) ? &cp[_offset] : nullptr, + &scales[(D_mask == 1) ? 0 : _offset], + oc_block, ic_block); + } + }); + +# undef wei_blk_off + + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const int oc = input_d.dims()[1]; + const int g = input_d.dims()[0]; + + return true + && order_keep + && input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + constexpr bool is_1d = tag_i == goiw; + constexpr int blksize = 16; + + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + const int G = dims[0]; + const int Gp = pdims[0]; + const int OC = dims[1]; + const int IC = dims[2]; + const int H = is_1d ? 1 : dims[3]; + const int W = dims[4 - is_1d]; + + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + const float *scales = pd->attr()->output_scales_.scales_; + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + auto ker = [&](const data_t *inp, data_t *out, + int32_t *cp, const float *s, const int g_block) { + PRAGMA_OMP_SIMD() + for (int g = 0; g < g_block; g++) { + const auto i_off = g * input_d.blocking_desc().strides[0]; + out[g] = qz_b0, data_t>()( + inp[i_off], s[g * OC] * adj_scale); + cp[g * OC] -= 128 * (int32_t)(out[g]); + } + }; + + size_t cp_offset = output_d.size() - output_d.additional_buffer_size(); + int32_t *cp = reinterpret_cast(output + cp_offset); + parallel_nd((Gp/blksize) * OC, [&](int ib) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < blksize; i++) + cp[ib * blksize + i] = 0; + }); + +# define wei_blk_off(md, g, o, i, h, w) \ + (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w)) + + parallel_nd(Gp/blksize, OC, [&](int gb, int O) { + for (int I = 0; I < IC; I++) { + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) + { + const int g_block = nstl::min(G - gb * blksize, blksize); + const auto inp = &input[wei_blk_off( + input_d, gb * blksize, O, I, h, w)]; + const auto out = &output[wei_blk_off( + output_d, gb, O, I, h, w)]; + int offset = gb * blksize + O; + ker(inp, out, &cp[offset], + &scales[(D_mask == 1) ? 0 : offset], g_block); + } + } + }); + +# undef wei_blk_off + + return success; + } +}; + +/* reorders with tail support */ + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + constexpr int is_1d = tag_i == nCw8c; + constexpr int is_3d = tag_i == nCdhw8c; + constexpr int blksize_16 = 16; + constexpr int blksize_8 = 8; + constexpr int ic_mult = order_keep ? 2 : 1; + constexpr int oc_mult = order_keep ? 1 : 2; + + const auto &dims = input_d.dims(); + const auto &pdims = order_keep ? output_d.padded_dims() + : input_d.padded_dims(); + + const int C = dims[1]; + const int D = is_3d ? dims[2] : 1; + const int H = is_1d ? 1 : dims[2 + is_3d]; + const int W = dims[3 + is_3d - is_1d]; + + auto ker = [&](const data_t *i, data_t *o, + const int block_16) { + const int nb = (block_16 - 1) / blksize_8 + 1; + if (alpha == 1.0 && beta == 0.0) { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off = order_keep ? b : b * blksize_8; + const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; + const int block_8 = nstl::min(blksize_8, + block_16 - b * blksize_8); + for (int c = 0; c < block_8; ++c) { + o[o_off + c] = _qz_a1b0()( + i[i_off + c]); + } + } + } else { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off = order_keep ? b : b * blksize_8; + const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; + const int block_8 = nstl::min(blksize_8, + block_16 - b * blksize_8); + for (int c = 0; c < block_8; ++c) { + o[o_off + c] = _qz()(i[i_off + c], + o[o_off + c], alpha, beta); + } + } + } + }; + +# define data_blk_off(md, n, c, d, h, w) \ + ( is_1d ? (md).blk_off(n, c, w) \ + : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w)) + + parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W, + [&](int n, int nb_c, int d, int h, int w) { + auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)]; + auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)]; + const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16); + ker(i, o, block_16); + }); + +# undef data_blk_off + + return success; + } +}; + +#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \ + static bool is_applicable(const memory_desc_wrapper &input_d, \ + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \ + return simple_attr_check(attr, false) && (order_keep \ + ? output_d.matches_tag(tag_o) && input_d.is_plain() \ + : input_d.matches_tag(tag_o) && output_d.is_plain()); \ + } + +template +struct simple_reorder_impl::block_dims == bd::_A + || tag_traits::block_dims == bd::_B) + && tag_traits::ndims >= 3 + && tag_traits::ndims <= 6 + >::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &block_d = order_keep ? output_d : input_d; + const auto &dims = input_d.dims(); + const auto &pdims = block_d.padded_dims(); + + constexpr int ndims = tag_traits::ndims; + constexpr int blk_idx = tag_traits::block_dims == bd::_A ? 0 : 1; + + const dim_t H0 = dims[0]; + const dim_t H1 = dims[1]; + const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1; + const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1; + const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1; + const dim_t L = dims[ndims - 1]; + const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1]; + + constexpr int blksize = false ? 0 + : utils::one_of(tag_traits::inner_blks, ib::_4a, ib::_4b) ? 4 + : utils::one_of(tag_traits::inner_blks, ib::_8a, ib::_8b) ? 8 + : 16; + + auto ker = [&](const data_t *i, data_t *o, int block) { + if (alpha == 1.0 && beta == 0.0) { + for (int l = 0; l < L; ++l) + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off = 0 + + blk * flat_d.blocking_desc().strides[blk_idx] + + l * flat_d.blocking_desc().strides[ndims - 1]; + if (order_keep) { + o[l * l_blk_stride + blk] = _qz_a1b0()( + i[flat_off]); + } else { + o[flat_off] = _qz_a1b0()( + i[l * l_blk_stride + blk]); + } + } + } else { + for (int l = 0; l < L; ++l) + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off = 0 + + blk * flat_d.blocking_desc().strides[blk_idx] + + l * flat_d.blocking_desc().strides[ndims - 1]; + if (order_keep) { + o[l * l_blk_stride + blk] = _qz()( + i[flat_off], o[l * blksize + blk], + alpha, beta); + } else { + o[flat_off] = _qz()( + i[l * l_blk_stride + blk], o[flat_off], + alpha, beta); + } + } + } + }; + +# define off(md, h0, h1, m0, m1, m2) \ + (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \ + : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \ + : ndims >= 4 ? (md).blk_off(h0, h1, m2) \ + : /* ndims >= 3 ? */ (md).blk_off(h0, h1)) + + constexpr int i_mult = order_keep ? blksize : 1; + constexpr int o_mult = order_keep ? 1 : blksize; + + if (blk_idx == 0) { + const dim_t BH0 = pdims[0] / blksize; + parallel_nd(BH0, H1, M0, M1, M2, + [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, bh0 * i_mult, h1, m0, m1, m2)]; + auto o = &output[off(output_d, bh0 * o_mult, h1, m0, m1, m2)]; + const int block = nstl::min(blksize, H0 - bh0 * blksize); + ker(i, o, block); + }); + } else if (blk_idx == 1) { + const dim_t BH1 = pdims[1] / blksize; + parallel_nd(H0, BH1, M0, M1, M2, + [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, h0, bh1 * i_mult, m0, m1, m2)]; + auto o = &output[off(output_d, h0, bh1 * o_mult, m0, m1, m2)]; + const int block = nstl::min(blksize, H1 - bh1 * blksize); + ker(i, o, block); + }); + } else { + assert(!"unimplemented"); + } + +# undef off + + return success; + } +}; + +template +struct simple_reorder_impl::block_dims == bd::_AB + || tag_traits::block_dims == bd::_BC) + && IMPLICATION(tag_traits::block_dims == bd::_AB, + tag_traits::ndims >= 3 && tag_traits::ndims <= 5) + && IMPLICATION(tag_traits::block_dims == bd::_BC, + tag_traits::ndims >= 4 && tag_traits::ndims <= 6) + >::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + constexpr int ndims = tag_traits::ndims; + + static constexpr bool with_g = tag_traits::block_dims == bd::_BC; + const dim_t G = with_g ? dims[0] : 1; + + const dim_t H0 = dims[0 + with_g]; + const dim_t H1 = dims[1 + with_g]; + + const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1; + const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1; + const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1; + + constexpr int blksize_0 = false ? 0 + : utils::one_of(tag_traits::inner_blks, + ib::_4b4a, ib::_4b4c, ib::_4c4b) + ? 4 + : utils::one_of(tag_traits::inner_blks, + ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : utils::one_of(tag_traits::inner_blks, + ib::_16a16b, ib::_16a4b, ib::_16b16a, ib::_16b4c, + ib::_16b16c, ib::_16c16b, ib::_8a16b2a, ib::_4b16a4b, + ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c) + ? 16 : INT_MIN; + + constexpr int blksize_1 = utils::one_of(tag_traits::inner_blks, + ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : utils::one_of(tag_traits::inner_blks, + ib::_16a16b, ib::_16b16a, ib::_16b16c, ib::_16c16b, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c) + ? 16 + : utils::one_of(tag_traits::inner_blks, + ib::_4b4a, ib::_4b4c, ib::_4c4b, + ib::_16a4b, ib::_16b4c) + ? 4 + : INT_MIN; + + const dim_t NB_H0 = pdims[0 + with_g] / blksize_0; + const dim_t NB_H1 = pdims[1 + with_g] / blksize_1; + + auto ker = [&](const data_t *i, data_t *o, + const int block_h0, const int block_h1) { +# define blk_off AB_or_BC_blk_off::inner_blks> + + if (alpha == 1.0 && beta == 0.0) { + for (int h0 = 0; h0 < block_h0; ++h0) + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off = 0 + + h0 * flat_d.blocking_desc().strides[with_g + 0] + + h1 * flat_d.blocking_desc().strides[with_g + 1]; + if (order_keep) { + o[blk_off(h0, h1)] = _qz_a1b0()( + i[flat_off]); + } else { + o[flat_off] = _qz_a1b0()( + i[blk_off(h0, h1)]); + } + } + } else { + for (int h0 = 0; h0 < block_h0; ++h0) + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off = 0 + + h0 * flat_d.blocking_desc().strides[with_g + 0] + + h1 * flat_d.blocking_desc().strides[with_g + 1]; + if (order_keep) { + o[blk_off(h0, h1)] = _qz()(i[flat_off], + o[blk_off(h0, h1)], alpha, beta); + } else { + o[flat_off] = _qz()(i[blk_off(h0, h1)], + o[flat_off], alpha, beta); + } + } + } + +# undef blk_off + }; + + constexpr int i_mult_0 = order_keep ? blksize_0 : 1; + constexpr int o_mult_0 = order_keep ? 1 : blksize_0; + + constexpr int i_mult_1 = order_keep ? blksize_1 : 1; + constexpr int o_mult_1 = order_keep ? 1 : blksize_1; + +# define off(md, g, h0, h1, m0, m1, m2) \ + (ndims >= 5 + with_g ? (md).blk_off(g, h0, h1, m0, m1, m2) \ + : ndims >= 4 + with_g ? (md).blk_off(g, h0, h1, m1, m2) \ + : /* ndims >= 3 + with_g ? */ (md).blk_off(g, h0, h1, m2)) + + parallel_nd(G, NB_H0, NB_H1, M0, M1, M2, + [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, + g, i_mult_0 * nb_h0, i_mult_1 * nb_h1, m0, m1, m2)]; + auto o = &output[off(output_d, + g, o_mult_0 * nb_h0, o_mult_1 * nb_h1, m0, m1, m2)]; + const int block_h0 = nstl::min(blksize_0, H0 - nb_h0 * blksize_0); + const int block_h1 = nstl::min(blksize_1, H1 - nb_h1 * blksize_1); + ker(i, o, block_h0, block_h1); + }); + +# undef off + + return success; + } +}; + +/* generic and direct-copy reorders */ + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + /* FIXME: is the formula correct? */ + return input_d.similar_to(output_d, true, false, 0) + && input_d.is_dense() && output_d.is_dense() + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + assert(input_d.is_dense()); + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const size_t nelems = input_d.nelems(); + + constexpr int block_size = 16; + const auto num_blocks = nelems / block_size; + const auto rem_elems = nelems % block_size; + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(num_blocks, nthr, ithr, start, end); + start = start * block_size; + end = end * block_size; + + if (alpha == 1.0 && beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_a1b0, data_t>() + (input[e]); + } + } else if (alpha == 1.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_a1, data_t>() + (input[e], output[e], beta); + } + } else if (beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_b0, data_t>() + (input[e], alpha); + } + } else { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz, data_t>() + (input[e], output[e], alpha, beta); + } + } + + if (rem_elems != 0 && ithr == nthr - 1){ + if (alpha == 1.0 && beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_a1b0, + data_t>()(input[e]); + } + } else if (alpha == 1.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_a1, + data_t>()(input[e], output[e], beta); + } + } else if (beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_b0, + data_t>()(input[e], alpha); + } + } else { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz, data_t>() + (input[e], output[e], alpha, beta); + } + } + } + }); + return success; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) { + return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d); + }; + /* FIXME: is the formula correct? */ + return input_d.similar_to(output_d, true, false, 1) + && is_dense_no_0(input_d) && is_dense_no_0(output_d) + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const int N = input_d.dims()[0]; + const dim_t is = input_d.blocking_desc().strides[0]; + const dim_t os = output_d.blocking_desc().strides[0]; + const dim_t nelems_no_d0 = nelems_no_dim_0(input_d); + const dim_t work_amount = N * nelems_no_d0; + + if (alpha == 1.0 && beta == 0.0) { + parallel(0, [&](const int ithr, const int nthr) { + dim_t n{0}, dim1_s{0}; + dim_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); + while(start < end) { + dim_t work_rem = end - start; + dim_t dim1_e = dim1_s + work_rem > nelems_no_d0 + ? nelems_no_d0 : dim1_s + work_rem; + PRAGMA_OMP_SIMD() + for (dim_t e = dim1_s; e < dim1_e; ++e) { + output[os * n + e] = _qz_a1b0()( + input[is * n + e]); + } + nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); + } + }); + } else { + parallel(0, [&](const int ithr, const int nthr) { + dim_t n{0}, dim1_s{0}; + dim_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); + while(start < end) { + dim_t work_rem = end - start; + dim_t dim1_e = + dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0 + : dim1_s + work_rem; + PRAGMA_OMP_SIMD() + for (dim_t e = dim1_s; e < dim1_e; ++e){ + output[os * n + e] = _qz()( + input[is * n + e], output[os * n + e], alpha, + beta); + } + nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); + } + }); + } + + return success; + } + +private: + static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) { + const int ndims = data_d.ndims(); + if (ndims <= 1) return 1; + return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1); + } + + static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) { + dims_t blocks; + data_d.compute_blocks(blocks); + + const auto &blk = data_d.blocking_desc(); + + dim_t blk_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) + blk_size *= blk.inner_blks[iblk]; + + dim_t max_size = blk_size; + for (int d = 1; d < data_d.ndims(); ++d) { + max_size = nstl::max(max_size, + data_d.padded_dims()[d] / blocks[d] * blk.strides[d]); + } + + return max_size; + } +}; + +template +struct simple_reorder_impl::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + /* supported smask: 0x0...011..10...0, + * i.e. 1 should be contiguous */ + int smask = attr ? attr->output_scales_.mask_ : 0; + for (; smask > 0 && !(smask & 0x1); smask >>= 1); + for (; smask > 0 && smask & 0x1; smask >>= 1); + return true + && input_d.is_blocking_desc() + && output_d.is_blocking_desc() + && !output_d.is_additional_buffer() + && !input_d.is_additional_buffer() + && smask == 0; + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t *input, data_t *output) { + DECLARE_COMMON_PARAMS(); + + const size_t nelems = input_d.nelems(); + + int ndims_start = 0, ndims_mask = 0; + int smask = pd->attr()->output_scales_.mask_; + for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start; + for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask; + assert(smask == 0); + + const ptrdiff_t D_start + = utils::array_product(input_d.dims(), ndims_start); + const ptrdiff_t D_mask + = utils::array_product(input_d.dims() + ndims_start, ndims_mask); + const ptrdiff_t D_rest = nelems / D_start / D_mask; + + const float *scales = pd->attr()->output_scales_.scales_; + + parallel_nd(D_start, D_mask, D_rest, + [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) { + const float scale = scales[dm]; + + const size_t e = (ds * D_mask + dm) * D_rest + dr; + const auto &i = input[input_d.off_l(e)]; + auto &o = output[output_d.off_l(e)]; + + o = _qz()(i, o, scale, beta); + }); + + return success; + } +}; + + +/* high level class declaration */ + +template +struct simple_reorder_t: public cpu_primitive_t { + struct pd_t: public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("simple:any", simple_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + bool args_ok = true + && src_md->data_type == type_i + && dst_md->data_type == type_o + && simple_reorder_impl:: + is_applicable(src_md, dst_md, attr); + if (!args_ok) + return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign(*reorder_pd, _pd); + } + }; + + simple_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_TO); + simple_reorder_impl::execute( + pd(), input, output); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +#undef SIMPLE_REORDER_TEMPL_DECL +#undef SIMPLE_REORDER_TEMPL_CALL + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp new file mode 100644 index 0000000000..f0947573a9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" + +#include "simple_sum.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +status_t simple_sum_t::execute(const exec_ctx_t &ctx) const { + auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper o_d(pd()->dst_md()); + output += o_d.blk_off(0); + + const int num_arrs = pd()->n_inputs(); + const data_t *input_ptrs[max_num_arrs]; + const size_t nelems = o_d.nelems(); + + for (int a = 0; a < num_arrs; ++a) { + const memory_desc_wrapper i_d(pd()->src_md(a)); + input_ptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) + + i_d.blk_off(0); + } + + const size_t block_size = 16 * 1024 / sizeof(data_type); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + + const auto scales = pd()->scales(); + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = data_t(scales[0] * input_ptrs[0][e]); + } + for (int a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += data_t(scales[a] * input_ptrs[a][e]); + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = data_t(scales[0] * input_ptrs[0][e]); + } + for (int a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += data_t(scales[a] * input_ptrs[a][e]); + } + } + } + }); + + return status::success; +} + +template struct simple_sum_t; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp new file mode 100644 index 0000000000..2a0187a184 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_SUM_HPP +#define SIMPLE_SUM_HPP + +#include "cpu_sum_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct simple_sum_t: public cpu_primitive_t { + struct pd_t: public cpu_sum_pd_t { + using cpu_sum_pd_t::cpu_sum_pd_t; + + DECLARE_SUM_PD_T("simple:any", simple_sum_t); + + status_t init() { + const int n = n_inputs(); + + bool ok = true + && cpu_sum_pd_t::init() == status::success + && n <= max_num_arrs; + if (!ok) return status::unimplemented; + + const memory_desc_wrapper o_d(dst_md()); + ok = ok + && o_d.data_type() == data_type + && o_d.is_dense(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n; ++i) { + const memory_desc_wrapper i_d(src_md(i)); + if (i_d != o_d) return status::unimplemented; + } + + return status::success; + } + }; + + simple_sum_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override; + + enum {max_num_arrs = 16 }; + typedef typename prec_traits::type data_t; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp new file mode 100644 index 0000000000..c2082d7d62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp @@ -0,0 +1,376 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_WINO_REORDER_HPP +#define CPU_WINO_REORDER_HPP + +#include "mkldnn_thread.hpp" + +#include "simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template +struct wino_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_tag(utils::pick(id.ndims() - 4, + format_tag::oihw, format_tag::goihw)) + && od.format_kind() == format_kind::wino + && utils::one_of(od.wino_desc().wino_format, + mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio); + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + auto &o = memory_desc_wrapper(dst_md()).wino_desc(); + size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block; + size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic; + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_reorder_wino_transform_space, + sizeof(in_data_t) * transform_space_size); + scratchpad.book(key_reorder_wino_plain, + sizeof(out_data_t) * plain_size); + } + }; + +private: + typedef typename prec_traits::type in_data_t; + typedef typename prec_traits::type out_data_t; + const int unsign_val_in_wino_domain_ = 5; + + wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + r_ = dst_d.wino_desc().r; + w_alpha_ = dst_d.wino_desc().alpha; + wino_format_ = dst_d.wino_desc().wino_format; + + const auto &in_dims = src_d.dims(); + int groups; + int groups_offset; + if (src_d.ndims() == 5) { + groups = in_dims[0]; + groups_offset = 1; + } else { + groups = 1; + groups_offset = 0; + } + assert(groups == 1); // groups are not supported now + MAYBE_UNUSED(groups); + + or_oc_ = in_dims[0 + groups_offset]; + or_ic_ = in_dims[1 + groups_offset]; + kh_ = in_dims[2 + groups_offset]; + kw_ = in_dims[3 + groups_offset]; + + oc_ = dst_d.wino_desc().oc; + ic_ = dst_d.wino_desc().ic; + oc_block_ = dst_d.wino_desc().oc_block; + ic_block_ = dst_d.wino_desc().ic_block; + assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); + nb_oc_ = oc_ / oc_block_; + nb_ic_ = ic_ / ic_block_; + ic2_block_ = 1; + if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + ic2_block_ = dst_d.wino_desc().ic2_block; + oc2_block_ = dst_d.wino_desc().oc2_block; + assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); + + adj_scale_ = dst_d.wino_desc().adj_scale; + + size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; + size_wspace_ = r_ * w_alpha_ * oc_block_; + } + + void transform(out_data_t *__restrict tmp_wei, + const in_data_t *__restrict input, + in_data_t *__restrict wspace) const { + const memory_desc_wrapper src_d(pd()->src_md()); + + const int smask = pd()->attr()->output_scales_.mask_; + const int ndims_mask = math::ilog2q(smask + 1); + const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask); + const float *__restrict scales = pd()->attr()->output_scales_.scales_; + assert(D_mask == 1 || D_mask == (size_t)oc_); + + /* transform weights to winograd domain */ + const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 }, + { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } }; + + const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f }, + { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f }, + { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f }, + { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f }, + { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f }, + { 0.f, 0.f, 1.f } }; + + float *__restrict g; + if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo)) + g = (float *)G_2x2_3x3; + else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + g = (float *)G_4x4_3x3; + else { + assert("Unknown winograd weights target layout"); + return; + } + + int Z = oc_ * ic_; + assert(r_ == kh_ && r_ == kw_); + + for (int iic = 0; iic < ic_; iic++) { + for (int ob = 0; ob < nb_oc_; ob++) { + const in_data_t *__restrict _inp + = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_; + out_data_t *__restrict _out + = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; + + for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; }); + + for_nd(0, 1, r_, w_alpha_, oc_block_, + [&](int ih, int j, int ioc) { + for (int iw = 0; iw < r_; ++iw) { + int inp_oc = ob * oc_block_ + ioc; + int inp_ic = iic; + in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_) + ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw] + : 0.f; + wspace[(ih * w_alpha_ + j) * oc_block_ + ioc] + += inp_v * g[j * r_ + iw]; + } + }); + + for_nd(0, 1, w_alpha_, w_alpha_, oc_block_, + [&](int i, int j, int ioc) { + float t = 0; + for (int k = 0; k < r_; ++k) + t += g[i * r_ + k] + * wspace[(k * w_alpha_ + j) * oc_block_ + ioc]; + if (type_o == data_type::s8) { + const float scale = (D_mask == 1) + ? scales[0] + : scales[ob * oc_block_ + ioc]; + _out[(i * w_alpha_ + j) * Z + ioc] + = qz_b0()( + (in_data_t)t, scale * adj_scale_); + } else { + _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t; + } + }); + }} + } + + void reorder_to_aaOIoi(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int32_t *__restrict dst_bias = nullptr; + if (type_o == data_type::s8) { + const auto bias_shift = sizeof(out_data_t) * size_wino_wei_; + const size_t bias_size = w_alpha_ * w_alpha_ * oc_; + + dst_bias = (int32_t *)(output + bias_shift); + utils::array_set((int32_t *)dst_bias, 0, bias_size); + } + int index = 0; + for (int u_h = 0; u_h < w_alpha_; u_h++) { + for (int u_w = 0; u_w < w_alpha_; u_w++) { + for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) { + int u_h_shift = u_h * w_alpha_ * ic_ * oc_; + int u_w_shift = u_w * ic_ * oc_; + int u_h_shift_b = u_h * w_alpha_ * oc_; + int u_w_shift_b = u_w * oc_; + int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_; + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int _i = ib * ic_block_; + int _o = ob * oc_block_; + int ic_shift = (_i + i) * oc_; + int oc_shift = (_o + o); + int ic_block_shift = ib * oc_block_ * ic_block_ + i; + int src_offset = + u_h_shift + u_w_shift + ic_shift + oc_shift; + int dst_offset = u_h_shift + u_w_shift + oc_block_shift + + ic_block_shift; + + output[dst_offset] = tmp_wei[src_offset]; + if (type_o == data_type::s8) { + int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift; + if (index != unsign_val_in_wino_domain_) + dst_bias[bias_offset] + -= (128 * (int32_t)output[dst_offset]); + else + dst_bias[bias_offset] = 0; + } + }} + }); + index++; + }} + } + + void reorder_to_aaOio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_, + [&](int u_h, int u_w, int ob) { + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + for (int o = 0; o < oc_block_; o++) { + int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_ + + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o); + + int dst_offset + = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + ob * nb_ic_ * ic_block_ * oc_block_ + + ib * ic_block_ * oc_block_ + i * oc_block_ + o; + output[dst_offset] = tmp_wei[src_offset]; + }}} + }); + } + + void reorder_to_aaOBiOo(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks, + [&](int u_h, int u_w, int occ) { + for (int ib = 0; ib < nb_ic_; ib++) { + out_data_t *__restrict wei_ptr = output + + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib) + * oc2_block_ * ic_block_ * oc_block_; + int wei_offset = 0; + for (int i = 0; i < ic_block_; i++) { + for (int ob2 = 0; ob2 < oc2_block_; ob2++) { + for (int o = 0; o < oc_block_; o++) { + int icp = ib * ic_block_ + i; + int ocp = + occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + wei_ptr[wei_offset + o] = tmp_wei[src_offset]; + } + wei_offset += oc_block_; + }} + } + }); + } + + void reorder_to_OBaaIBOIio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int ic_chunks = nb_ic_ / ic2_block_; + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_, + [&](int occ, int u_h, int u_w) { + for (int icc = 0; icc < ic_chunks; icc++) { + for (int ob = 0; ob < oc2_block_; ob++) { + int ocp = (occ * oc2_block_ + ob) * oc_block_; + for (int ib = 0; ib < ic2_block_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int icp = (icc * ic2_block_ + ib) * ic_block_ + i; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + int wei_offset + = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w) + * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_ + + ib) * ic_block_ + i) * oc_block_; + for (int o = 0; o < oc_block_; o++) + output[wei_offset + o] = tmp_wei[src_offset + o]; + }} + }} + }); + } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + + auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_wino_transform_space); + auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get( + memory_tracking::names::key_reorder_wino_plain); + + transform(tmp_wei, input, wspace); + + /* reorder to winograd domain */ + switch (wino_format_) { + case mkldnn_wino_wei_aaOIoi: + reorder_to_aaOIoi(output, tmp_wei); break; + case mkldnn_wino_wei_aaOio: + reorder_to_aaOio(output, tmp_wei); break; + case mkldnn_wino_wei_aaOBiOo: + reorder_to_aaOBiOo(output, tmp_wei); break; + case mkldnn_wino_wei_OBaaIBOIio: + reorder_to_OBaaIBOIio(output, tmp_wei); break; + default: assert("Unknown wino format"); break; + } + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int r_, w_alpha_; + int ic_, oc_, or_ic_, or_oc_, kh_, kw_; + int oc_block_, ic_block_, oc2_block_, ic2_block_; + float adj_scale_; + int nb_oc_, nb_ic_; + mkldnn_wino_memory_format_t wino_format_; + int size_wino_wei_; + int size_wspace_; +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT new file mode 100644 index 0000000000..66b6ea55d0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT @@ -0,0 +1,47 @@ + +Copyright (c) 2007 MITSUNARI Shigeo +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. +Neither the name of the copyright owner nor the names of its contributors may +be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGE. +----------------------------------------------------------------------------- +ソースコード形式かバイナリ形式か、変更するかしないかを問わず、以下の条件を満た +す場合に限り、再頒布および使用が許可されます。 + +ソースコードを再頒布する場合、上記の著作権表示、本条件一覧、および下記免責条項 +を含めること。 +バイナリ形式で再頒布する場合、頒布物に付属のドキュメント等の資料に、上記の著作 +権表示、本条件一覧、および下記免責条項を含めること。 +書面による特別の許可なしに、本ソフトウェアから派生した製品の宣伝または販売促進 +に、著作権者の名前またはコントリビューターの名前を使用してはならない。 +本ソフトウェアは、著作権者およびコントリビューターによって「現状のまま」提供さ +れており、明示黙示を問わず、商業的な使用可能性、および特定の目的に対する適合性 +に関する暗黙の保証も含め、またそれに限定されない、いかなる保証もありません。 +著作権者もコントリビューターも、事由のいかんを問わず、 損害発生の原因いかんを +問わず、かつ責任の根拠が契約であるか厳格責任であるか(過失その他の)不法行為で +あるかを問わず、仮にそのような損害が発生する可能性を知らされていたとしても、 +本ソフトウェアの使用によって発生した(代替品または代用サービスの調達、使用の +喪失、データの喪失、利益の喪失、業務の中断も含め、またそれに限定されない)直接 +損害、間接損害、偶発的な損害、特別損害、懲罰的損害、または結果損害について、 +一切責任を負わないものとします。 diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h new file mode 100644 index 0000000000..cf5771332f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h @@ -0,0 +1,2658 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#pragma once +#ifndef XBYAK_XBYAK_H_ +#define XBYAK_XBYAK_H_ +/*! + @file xbyak.h + @brief Xbyak ; JIT assembler for x86(IA32)/x64 by C++ + @author herumi + @url https://github.com/herumi/xbyak + @note modified new BSD license + http://opensource.org/licenses/BSD-3-Clause +*/ +#ifndef XBYAK_NO_OP_NAMES + #if not +0 // trick to detect whether 'not' is operator or not + #error "use -fno-operator-names option if you want to use and(), or(), xor(), not() as function names, Or define XBYAK_NO_OP_NAMES and use and_(), or_(), xor_(), not_()." + #endif +#endif + +#include // for debug print +#include +#include +#include +#include +#ifndef NDEBUG +#include +#endif + +// #define XBYAK_DISABLE_AVX512 + +//#define XBYAK_USE_MMAP_ALLOCATOR +#if !defined(__GNUC__) || defined(__MINGW32__) + #undef XBYAK_USE_MMAP_ALLOCATOR +#endif + +#ifdef __GNUC__ + #define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor)) +#else + #define XBYAK_GNUC_PREREQ(major, minor) 0 +#endif + +// This covers -std=(gnu|c)++(0x|11|1y), -stdlib=libc++, and modern Microsoft. +#if ((defined(_MSC_VER) && (_MSC_VER >= 1600)) || defined(_LIBCPP_VERSION) ||\ + ((__cplusplus >= 201103) || defined(__GXX_EXPERIMENTAL_CXX0X__))) + #include + #define XBYAK_STD_UNORDERED_SET std::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::unordered_multimap + +/* + Clang/llvm-gcc and ICC-EDG in 'GCC-mode' always claim to be GCC 4.2, using + libstdcxx 20070719 (from GCC 4.2.1, the last GPL 2 version). +*/ +#elif XBYAK_GNUC_PREREQ(4, 5) || (XBYAK_GNUC_PREREQ(4, 2) && __GLIBCXX__ >= 20070719) || defined(__INTEL_COMPILER) || defined(__llvm__) + #include + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#elif defined(_MSC_VER) && (_MSC_VER >= 1500) && (_MSC_VER < 1600) + #include + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#else + #include + #define XBYAK_STD_UNORDERED_SET std::set + #include + #define XBYAK_STD_UNORDERED_MAP std::map + #define XBYAK_STD_UNORDERED_MULTIMAP std::multimap +#endif +#ifdef _WIN32 + #include + #include + #include +#elif defined(__GNUC__) + #include + #include + #include +#endif +#if !defined(_MSC_VER) || (_MSC_VER >= 1600) + #include +#endif + +#if defined(_WIN64) || defined(__MINGW64__) || (defined(__CYGWIN__) && defined(__x86_64__)) + #define XBYAK64_WIN +#elif defined(__x86_64__) + #define XBYAK64_GCC +#endif +#if !defined(XBYAK64) && !defined(XBYAK32) + #if defined(XBYAK64_GCC) || defined(XBYAK64_WIN) + #define XBYAK64 + #else + #define XBYAK32 + #endif +#endif + +#if (__cplusplus >= 201103) || (_MSC_VER >= 1800) + #define XBYAK_VARIADIC_TEMPLATE +#endif + +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) /* remove inline function */ + #pragma warning(disable : 4786) /* identifier is too long */ + #pragma warning(disable : 4503) /* name is too long */ + #pragma warning(disable : 4127) /* constant expresison */ +#endif + +namespace Xbyak { + +enum { + DEFAULT_MAX_CODE_SIZE = 4096, + VERSION = 0x5760 /* 0xABCD = A.BC(D) */ +}; + +#ifndef MIE_INTEGER_TYPE_DEFINED +#define MIE_INTEGER_TYPE_DEFINED +#ifdef _MSC_VER + typedef unsigned __int64 uint64; + typedef __int64 sint64; +#else + typedef uint64_t uint64; + typedef int64_t sint64; +#endif +typedef unsigned int uint32; +typedef unsigned short uint16; +typedef unsigned char uint8; +#endif + +#ifndef MIE_ALIGN + #ifdef _MSC_VER + #define MIE_ALIGN(x) __declspec(align(x)) + #else + #define MIE_ALIGN(x) __attribute__((aligned(x))) + #endif +#endif +#ifndef MIE_PACK // for shufps + #define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w)) +#endif + +enum { + ERR_NONE = 0, + ERR_BAD_ADDRESSING, + ERR_CODE_IS_TOO_BIG, + ERR_BAD_SCALE, + ERR_ESP_CANT_BE_INDEX, + ERR_BAD_COMBINATION, + ERR_BAD_SIZE_OF_REGISTER, + ERR_IMM_IS_TOO_BIG, + ERR_BAD_ALIGN, + ERR_LABEL_IS_REDEFINED, + ERR_LABEL_IS_TOO_FAR, + ERR_LABEL_IS_NOT_FOUND, + ERR_CODE_ISNOT_COPYABLE, + ERR_BAD_PARAMETER, + ERR_CANT_PROTECT, + ERR_CANT_USE_64BIT_DISP, + ERR_OFFSET_IS_TOO_BIG, + ERR_MEM_SIZE_IS_NOT_SPECIFIED, + ERR_BAD_MEM_SIZE, + ERR_BAD_ST_COMBINATION, + ERR_OVER_LOCAL_LABEL, // not used + ERR_UNDER_LOCAL_LABEL, + ERR_CANT_ALLOC, + ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW, + ERR_BAD_PROTECT_MODE, + ERR_BAD_PNUM, + ERR_BAD_TNUM, + ERR_BAD_VSIB_ADDRESSING, + ERR_CANT_CONVERT, + ERR_LABEL_ISNOT_SET_BY_L, + ERR_LABEL_IS_ALREADY_SET_BY_L, + ERR_BAD_LABEL_STR, + ERR_MUNMAP, + ERR_OPMASK_IS_ALREADY_SET, + ERR_ROUNDING_IS_ALREADY_SET, + ERR_K0_IS_INVALID, + ERR_EVEX_IS_INVALID, + ERR_SAE_IS_INVALID, + ERR_ER_IS_INVALID, + ERR_INVALID_BROADCAST, + ERR_INVALID_OPMASK_WITH_MEMORY, + ERR_INVALID_ZERO, + ERR_INVALID_RIP_IN_AUTO_GROW, + ERR_INVALID_MIB_ADDRESS, + ERR_INTERNAL, + ERR_X2APIC_IS_NOT_SUPPORTED +}; + +class Error : public std::exception { + int err_; +public: + explicit Error(int err) : err_(err) + { + if (err_ < 0 || err_ > ERR_INTERNAL) { + fprintf(stderr, "bad err=%d in Xbyak::Error\n", err_); + //exit(1); + } + } + operator int() const { return err_; } + const char *what() const throw() + { + static const char *errTbl[] = { + "none", + "bad addressing", + "code is too big", + "bad scale", + "esp can't be index", + "bad combination", + "bad size of register", + "imm is too big", + "bad align", + "label is redefined", + "label is too far", + "label is not found", + "code is not copyable", + "bad parameter", + "can't protect", + "can't use 64bit disp(use (void*))", + "offset is too big", + "MEM size is not specified", + "bad mem size", + "bad st combination", + "over local label", + "under local label", + "can't alloc", + "T_SHORT is not supported in AutoGrow", + "bad protect mode", + "bad pNum", + "bad tNum", + "bad vsib addressing", + "can't convert", + "label is not set by L()", + "label is already set by L()", + "bad label string", + "err munmap", + "opmask is already set", + "rounding is already set", + "k0 is invalid", + "evex is invalid", + "sae(suppress all exceptions) is invalid", + "er(embedded rounding) is invalid", + "invalid broadcast", + "invalid opmask with memory", + "invalid zero", + "invalid rip in AutoGrow", + "invalid mib address", + "internal error", + "x2APIC is not supported" + }; + assert((size_t)err_ < sizeof(errTbl) / sizeof(*errTbl)); + return errTbl[err_]; + } +}; + +inline const char *ConvertErrorToString(const Error& err) +{ + return err.what(); +} + +inline void *AlignedMalloc(size_t size, size_t alignment) +{ +#ifdef __MINGW32__ + return __mingw_aligned_malloc(size, alignment); +#elif defined(_WIN32) + return _aligned_malloc(size, alignment); +#else + void *p; + int ret = posix_memalign(&p, alignment, size); + return (ret == 0) ? p : 0; +#endif +} + +inline void AlignedFree(void *p) +{ +#ifdef __MINGW32__ + __mingw_aligned_free(p); +#elif defined(_MSC_VER) + _aligned_free(p); +#else + free(p); +#endif +} + +template +inline const To CastTo(From p) throw() +{ + return (const To)(size_t)(p); +} +namespace inner { + +static const size_t ALIGN_PAGE_SIZE = 4096; + +inline bool IsInDisp8(uint32 x) { return 0xFFFFFF80 <= x || x <= 0x7F; } +inline bool IsInInt32(uint64 x) { return ~uint64(0x7fffffffu) <= x || x <= 0x7FFFFFFFU; } + +inline uint32 VerifyInInt32(uint64 x) +{ +#ifdef XBYAK64 + if (!IsInInt32(x)) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + return static_cast(x); +} + +enum LabelMode { + LasIs, // as is + Labs, // absolute + LaddTop // (addr + top) for mov(reg, label) with AutoGrow +}; + +} // inner + +/* + custom allocator +*/ +struct Allocator { + virtual uint8 *alloc(size_t size) { return reinterpret_cast(AlignedMalloc(size, inner::ALIGN_PAGE_SIZE)); } + virtual void free(uint8 *p) { AlignedFree(p); } + virtual ~Allocator() {} + /* override to return false if you call protect() manually */ + virtual bool useProtect() const { return true; } +}; + +#ifdef XBYAK_USE_MMAP_ALLOCATOR +class MmapAllocator : Allocator { + typedef XBYAK_STD_UNORDERED_MAP SizeList; + SizeList sizeList_; +public: + uint8 *alloc(size_t size) + { + const size_t alignedSizeM1 = inner::ALIGN_PAGE_SIZE - 1; + size = (size + alignedSizeM1) & ~alignedSizeM1; +#ifdef MAP_ANONYMOUS + const int mode = MAP_PRIVATE | MAP_ANONYMOUS; +#elif defined(MAP_ANON) + const int mode = MAP_PRIVATE | MAP_ANON; +#else + #error "not supported" +#endif + void *p = mmap(NULL, size, PROT_READ | PROT_WRITE, mode, -1, 0); + if (p == MAP_FAILED) throw Error(ERR_CANT_ALLOC); + assert(p); + sizeList_[(uintptr_t)p] = size; + return (uint8*)p; + } + void free(uint8 *p) + { + if (p == 0) return; + SizeList::iterator i = sizeList_.find((uintptr_t)p); + if (i == sizeList_.end()) throw Error(ERR_BAD_PARAMETER); + if (munmap((void*)i->first, i->second) < 0) throw Error(ERR_MUNMAP); + sizeList_.erase(i); + } +}; +#endif + +class Address; +class Reg; + +class Operand { + static const uint8 EXT8BIT = 0x20; + unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil + unsigned int kind_:9; + unsigned int bit_:10; +protected: + unsigned int zero_:1; + unsigned int mask_:3; + unsigned int rounding_:3; + void setIdx(int idx) { idx_ = idx; } +public: + enum Kind { + NONE = 0, + MEM = 1 << 0, + REG = 1 << 1, + MMX = 1 << 2, + FPU = 1 << 3, + XMM = 1 << 4, + YMM = 1 << 5, + ZMM = 1 << 6, + OPMASK = 1 << 7, + BNDREG = 1 << 8 + }; + enum Code { +#ifdef XBYAK64 + RAX = 0, RCX, RDX, RBX, RSP, RBP, RSI, RDI, R8, R9, R10, R11, R12, R13, R14, R15, + R8D = 8, R9D, R10D, R11D, R12D, R13D, R14D, R15D, + R8W = 8, R9W, R10W, R11W, R12W, R13W, R14W, R15W, + R8B = 8, R9B, R10B, R11B, R12B, R13B, R14B, R15B, + SPL = 4, BPL, SIL, DIL, +#endif + EAX = 0, ECX, EDX, EBX, ESP, EBP, ESI, EDI, + AX = 0, CX, DX, BX, SP, BP, SI, DI, + AL = 0, CL, DL, BL, AH, CH, DH, BH + }; + Operand() : idx_(0), kind_(0), bit_(0), zero_(0), mask_(0), rounding_(0) { } + Operand(int idx, Kind kind, int bit, bool ext8bit = 0) + : idx_(static_cast(idx | (ext8bit ? EXT8BIT : 0))) + , kind_(kind) + , bit_(bit) + , zero_(0), mask_(0), rounding_(0) + { + assert((bit_ & (bit_ - 1)) == 0); // bit must be power of two + } + Kind getKind() const { return static_cast(kind_); } + int getIdx() const { return idx_ & (EXT8BIT - 1); } + bool isNone() const { return kind_ == 0; } + bool isMMX() const { return is(MMX); } + bool isXMM() const { return is(XMM); } + bool isYMM() const { return is(YMM); } + bool isZMM() const { return is(ZMM); } + bool isXMEM() const { return is(XMM | MEM); } + bool isYMEM() const { return is(YMM | MEM); } + bool isZMEM() const { return is(ZMM | MEM); } + bool isOPMASK() const { return is(OPMASK); } + bool isBNDREG() const { return is(BNDREG); } + bool isREG(int bit = 0) const { return is(REG, bit); } + bool isMEM(int bit = 0) const { return is(MEM, bit); } + bool isFPU() const { return is(FPU); } + bool isExt8bit() const { return (idx_ & EXT8BIT) != 0; } + bool isExtIdx() const { return (getIdx() & 8) != 0; } + bool isExtIdx2() const { return (getIdx() & 16) != 0; } + bool hasEvex() const { return isZMM() || isExtIdx2() || getOpmaskIdx() || getRounding(); } + bool hasRex() const { return isExt8bit() || isREG(64) || isExtIdx(); } + bool hasZero() const { return zero_; } + int getOpmaskIdx() const { return mask_; } + int getRounding() const { return rounding_; } + void setKind(Kind kind) + { + if ((kind & (XMM|YMM|ZMM)) == 0) return; + kind_ = kind; + bit_ = kind == XMM ? 128 : kind == YMM ? 256 : 512; + } + void setBit(int bit) { bit_ = bit; } + void setOpmaskIdx(int idx, bool ignore_idx0 = false) + { + if (!ignore_idx0 && idx == 0) throw Error(ERR_K0_IS_INVALID); + if (mask_) throw Error(ERR_OPMASK_IS_ALREADY_SET); + mask_ = idx; + } + void setRounding(int idx) + { + if (rounding_) throw Error(ERR_ROUNDING_IS_ALREADY_SET); + rounding_ = idx; + } + void setZero() { zero_ = true; } + // ah, ch, dh, bh? + bool isHigh8bit() const + { + if (!isBit(8)) return false; + if (isExt8bit()) return false; + const int idx = getIdx(); + return AH <= idx && idx <= BH; + } + // any bit is accetable if bit == 0 + bool is(int kind, uint32 bit = 0) const + { + return (kind == 0 || (kind_ & kind)) && (bit == 0 || (bit_ & bit)); // cf. you can set (8|16) + } + bool isBit(uint32 bit) const { return (bit_ & bit) != 0; } + uint32 getBit() const { return bit_; } + const char *toString() const + { + const int idx = getIdx(); + if (kind_ == REG) { + if (isExt8bit()) { + static const char *tbl[4] = { "spl", "bpl", "sil", "dil" }; + return tbl[idx - 4]; + } + static const char *tbl[4][16] = { + { "al", "cl", "dl", "bl", "ah", "ch", "dh", "bh", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b" }, + { "ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w" }, + { "eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d" }, + { "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" }, + }; + return tbl[bit_ == 8 ? 0 : bit_ == 16 ? 1 : bit_ == 32 ? 2 : 3][idx]; + } else if (isOPMASK()) { + static const char *tbl[8] = { "k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7" }; + return tbl[idx]; + } else if (isZMM()) { + static const char *tbl[32] = { + "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31" + }; + return tbl[idx]; + } else if (isYMM()) { + static const char *tbl[32] = { + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", + "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", "ymm31" + }; + return tbl[idx]; + } else if (isXMM()) { + static const char *tbl[32] = { + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", + "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", "xmm27", "xmm28", "xmm29", "xmm30", "xmm31" + }; + return tbl[idx]; + } else if (isMMX()) { + static const char *tbl[8] = { "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7" }; + return tbl[idx]; + } else if (isFPU()) { + static const char *tbl[8] = { "st0", "st1", "st2", "st3", "st4", "st5", "st6", "st7" }; + return tbl[idx]; + } else if (isBNDREG()) { + static const char *tbl[4] = { "bnd0", "bnd1", "bnd2", "bnd3" }; + return tbl[idx]; + } + throw Error(ERR_INTERNAL); + } + bool isEqualIfNotInherited(const Operand& rhs) const { return idx_ == rhs.idx_ && kind_ == rhs.kind_ && bit_ == rhs.bit_ && zero_ == rhs.zero_ && mask_ == rhs.mask_ && rounding_ == rhs.rounding_; } + bool operator==(const Operand& rhs) const; + bool operator!=(const Operand& rhs) const { return !operator==(rhs); } + const Address& getAddress() const; + const Reg& getReg() const; +}; + +class Label; + +struct Reg8; +struct Reg16; +struct Reg32; +#ifdef XBYAK64 +struct Reg64; +#endif +class Reg : public Operand { +public: + Reg() { } + Reg(int idx, Kind kind, int bit = 0, bool ext8bit = false) : Operand(idx, kind, bit, ext8bit) { } + Reg changeBit(int bit) const { return Reg(getIdx(), getKind(), bit, isExt8bit()); } + uint8 getRexW() const { return isREG(64) ? 8 : 0; } + uint8 getRexR() const { return isExtIdx() ? 4 : 0; } + uint8 getRexX() const { return isExtIdx() ? 2 : 0; } + uint8 getRexB() const { return isExtIdx() ? 1 : 0; } + uint8 getRex(const Reg& base = Reg()) const + { + uint8 rex = getRexW() | getRexR() | base.getRexW() | base.getRexB(); + if (rex || isExt8bit() || base.isExt8bit()) rex |= 0x40; + return rex; + } + Reg8 cvt8() const; + Reg16 cvt16() const; + Reg32 cvt32() const; +#ifdef XBYAK64 + Reg64 cvt64() const; +#endif +}; + +inline const Reg& Operand::getReg() const +{ + assert(!isMEM()); + return static_cast(*this); +} + +struct Reg8 : public Reg { + explicit Reg8(int idx = 0, bool ext8bit = false) : Reg(idx, Operand::REG, 8, ext8bit) { } +}; + +struct Reg16 : public Reg { + explicit Reg16(int idx = 0) : Reg(idx, Operand::REG, 16) { } +}; + +struct Mmx : public Reg { + explicit Mmx(int idx = 0, Kind kind = Operand::MMX, int bit = 64) : Reg(idx, kind, bit) { } +}; + +struct EvexModifierRounding { + enum { + T_RN_SAE = 1, + T_RD_SAE = 2, + T_RU_SAE = 3, + T_RZ_SAE = 4, + T_SAE = 5 + }; + explicit EvexModifierRounding(int rounding) : rounding(rounding) {} + int rounding; +}; +struct EvexModifierZero{EvexModifierZero() {}}; + +struct Xmm : public Mmx { + explicit Xmm(int idx = 0, Kind kind = Operand::XMM, int bit = 128) : Mmx(idx, kind, bit) { } + Xmm(Kind kind, int idx) : Mmx(idx, kind, kind == XMM ? 128 : kind == YMM ? 256 : 512) { } + Xmm operator|(const EvexModifierRounding& emr) const { Xmm r(*this); r.setRounding(emr.rounding); return r; } + Xmm copyAndSetIdx(int idx) const { Xmm ret(*this); ret.setIdx(idx); return ret; } + Xmm copyAndSetKind(Operand::Kind kind) const { Xmm ret(*this); ret.setKind(kind); return ret; } +}; + +struct Ymm : public Xmm { + explicit Ymm(int idx = 0, Kind kind = Operand::YMM, int bit = 256) : Xmm(idx, kind, bit) { } + Ymm operator|(const EvexModifierRounding& emr) const { Ymm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Zmm : public Ymm { + explicit Zmm(int idx = 0) : Ymm(idx, Operand::ZMM, 512) { } + Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Opmask : public Reg { + explicit Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {} +}; + +struct BoundsReg : public Reg { + explicit BoundsReg(int idx = 0) : Reg(idx, Operand::BNDREG, 128) {} +}; + +templateT operator|(const T& x, const Opmask& k) { T r(x); r.setOpmaskIdx(k.getIdx()); return r; } +templateT operator|(const T& x, const EvexModifierZero&) { T r(x); r.setZero(); return r; } +templateT operator|(const T& x, const EvexModifierRounding& emr) { T r(x); r.setRounding(emr.rounding); return r; } + +struct Fpu : public Reg { + explicit Fpu(int idx = 0) : Reg(idx, Operand::FPU, 32) { } +}; + +struct Reg32e : public Reg { + explicit Reg32e(int idx, int bit) : Reg(idx, Operand::REG, bit) {} +}; +struct Reg32 : public Reg32e { + explicit Reg32(int idx = 0) : Reg32e(idx, 32) {} +}; +#ifdef XBYAK64 +struct Reg64 : public Reg32e { + explicit Reg64(int idx = 0) : Reg32e(idx, 64) {} +}; +struct RegRip { + sint64 disp_; + const Label* label_; + bool isAddr_; + explicit RegRip(sint64 disp = 0, const Label* label = 0, bool isAddr = false) : disp_(disp), label_(label), isAddr_(isAddr) {} + friend const RegRip operator+(const RegRip& r, int disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, int disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, sint64 disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, sint64 disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, const Label& label) { + if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); + return RegRip(r.disp_, &label); + } + friend const RegRip operator+(const RegRip& r, const void *addr) { + if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); + return RegRip(r.disp_ + (sint64)addr, 0, true); + } +}; +#endif + +inline Reg8 Reg::cvt8() const +{ + const int idx = getIdx(); + if (isBit(8)) return Reg8(idx, isExt8bit()); +#ifdef XBYAK32 + if (idx >= 4) throw Error(ERR_CANT_CONVERT); +#endif + return Reg8(idx, 4 <= idx && idx < 8); +} + +inline Reg16 Reg::cvt16() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg16(idx); +} + +inline Reg32 Reg::cvt32() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg32(idx); +} + +#ifdef XBYAK64 +inline Reg64 Reg::cvt64() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg64(idx); +} +#endif + +#ifndef XBYAK_DISABLE_SEGMENT +// not derived from Reg +class Segment { + int idx_; +public: + enum { + es, cs, ss, ds, fs, gs + }; + explicit Segment(int idx) : idx_(idx) { assert(0 <= idx_ && idx_ < 6); } + int getIdx() const { return idx_; } + const char *toString() const + { + static const char tbl[][3] = { + "es", "cs", "ss", "ds", "fs", "gs" + }; + return tbl[idx_]; + } +}; +#endif + +class RegExp { +public: +#ifdef XBYAK64 + enum { i32e = 32 | 64 }; +#else + enum { i32e = 32 }; +#endif + RegExp(size_t disp = 0) : scale_(0), disp_(disp) { } + RegExp(const Reg& r, int scale = 1) + : scale_(scale) + , disp_(0) + { + if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER); + if (scale == 0) return; + if (scale != 1 && scale != 2 && scale != 4 && scale != 8) throw Error(ERR_BAD_SCALE); + if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index + index_ = r; + } else { + base_ = r; + } + } + bool isVsib(int bit = 128 | 256 | 512) const { return index_.isBit(bit); } + RegExp optimize() const + { + RegExp exp = *this; + // [reg * 2] => [reg + reg] + if (index_.isBit(i32e) && !base_.getBit() && scale_ == 2) { + exp.base_ = index_; + exp.scale_ = 1; + } + return exp; + } + bool operator==(const RegExp& rhs) const + { + return base_ == rhs.base_ && index_ == rhs.index_ && disp_ == rhs.disp_ && scale_ == rhs.scale_; + } + const Reg& getBase() const { return base_; } + const Reg& getIndex() const { return index_; } + int getScale() const { return scale_; } + size_t getDisp() const { return disp_; } + void verify() const + { + if (base_.getBit() >= 128) throw Error(ERR_BAD_SIZE_OF_REGISTER); + if (index_.getBit() && index_.getBit() <= 64) { + if (index_.getIdx() == Operand::ESP) throw Error(ERR_ESP_CANT_BE_INDEX); + if (base_.getBit() && base_.getBit() != index_.getBit()) throw Error(ERR_BAD_SIZE_OF_REGISTER); + } + } + friend RegExp operator+(const RegExp& a, const RegExp& b); + friend RegExp operator-(const RegExp& e, size_t disp); + uint8 getRex() const + { + uint8 rex = index_.getRexX() | base_.getRexB(); + return rex ? uint8(rex | 0x40) : 0; + } +private: + /* + [base_ + index_ * scale_ + disp_] + base : Reg32e, index : Reg32e(w/o esp), Xmm, Ymm + */ + Reg base_; + Reg index_; + int scale_; + size_t disp_; +}; + +inline RegExp operator+(const RegExp& a, const RegExp& b) +{ + if (a.index_.getBit() && b.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); + RegExp ret = a; + if (!ret.index_.getBit()) { ret.index_ = b.index_; ret.scale_ = b.scale_; } + if (b.base_.getBit()) { + if (ret.base_.getBit()) { + if (ret.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); + // base + base => base + index * 1 + ret.index_ = b.base_; + // [reg + esp] => [esp + reg] + if (ret.index_.getIdx() == Operand::ESP) std::swap(ret.base_, ret.index_); + ret.scale_ = 1; + } else { + ret.base_ = b.base_; + } + } + ret.disp_ += b.disp_; + return ret; +} +inline RegExp operator*(const Reg& r, int scale) +{ + return RegExp(r, scale); +} +inline RegExp operator-(const RegExp& e, size_t disp) +{ + RegExp ret = e; + ret.disp_ -= disp; + return ret; +} + +// 2nd parameter for constructor of CodeArray(maxSize, userPtr, alloc) +void *const AutoGrow = (void*)1; //-V566 +void *const DontSetProtectRWE = (void*)2; //-V566 + +class CodeArray { + enum Type { + USER_BUF = 1, // use userPtr(non alignment, non protect) + ALLOC_BUF, // use new(alignment, protect) + AUTO_GROW // automatically move and grow memory if necessary + }; + CodeArray(const CodeArray& rhs); + void operator=(const CodeArray&); + bool isAllocType() const { return type_ == ALLOC_BUF || type_ == AUTO_GROW; } + struct AddrInfo { + size_t codeOffset; // position to write + size_t jmpAddr; // value to write + int jmpSize; // size of jmpAddr + inner::LabelMode mode; + AddrInfo(size_t _codeOffset, size_t _jmpAddr, int _jmpSize, inner::LabelMode _mode) + : codeOffset(_codeOffset), jmpAddr(_jmpAddr), jmpSize(_jmpSize), mode(_mode) {} + uint64 getVal(const uint8 *top) const + { + uint64 disp = (mode == inner::LaddTop) ? jmpAddr + size_t(top) : (mode == inner::LasIs) ? jmpAddr : jmpAddr - size_t(top); + if (jmpSize == 4) disp = inner::VerifyInInt32(disp); + return disp; + } + }; + typedef std::list AddrInfoList; + AddrInfoList addrInfoList_; + const Type type_; +#ifdef XBYAK_USE_MMAP_ALLOCATOR + MmapAllocator defaultAllocator_; +#else + Allocator defaultAllocator_; +#endif + Allocator *alloc_; +protected: + size_t maxSize_; + uint8 *top_; + size_t size_; + bool isCalledCalcJmpAddress_; + + bool useProtect() const { return alloc_->useProtect(); } + /* + allocate new memory and copy old data to the new area + */ + void growMemory() + { + const size_t newSize = (std::max)(DEFAULT_MAX_CODE_SIZE, maxSize_ * 2); + uint8 *newTop = alloc_->alloc(newSize); + if (newTop == 0) throw Error(ERR_CANT_ALLOC); + for (size_t i = 0; i < size_; i++) newTop[i] = top_[i]; + alloc_->free(top_); + top_ = newTop; + maxSize_ = newSize; + } + /* + calc jmp address for AutoGrow mode + */ + void calcJmpAddress() + { + if (isCalledCalcJmpAddress_) return; + for (AddrInfoList::const_iterator i = addrInfoList_.begin(), ie = addrInfoList_.end(); i != ie; ++i) { + uint64 disp = i->getVal(top_); + rewrite(i->codeOffset, disp, i->jmpSize); + } + isCalledCalcJmpAddress_ = true; + } +public: + enum ProtectMode { + PROTECT_RW = 0, // read/write + PROTECT_RWE = 1, // read/write/exec + PROTECT_RE = 2 // read/exec + }; + explicit CodeArray(size_t maxSize, void *userPtr = 0, Allocator *allocator = 0) + : type_(userPtr == AutoGrow ? AUTO_GROW : (userPtr == 0 || userPtr == DontSetProtectRWE) ? ALLOC_BUF : USER_BUF) + , alloc_(allocator ? allocator : (Allocator*)&defaultAllocator_) + , maxSize_(maxSize) + , top_(type_ == USER_BUF ? reinterpret_cast(userPtr) : alloc_->alloc((std::max)(maxSize, 1))) + , size_(0) + , isCalledCalcJmpAddress_(false) + { + if (maxSize_ > 0 && top_ == 0) throw Error(ERR_CANT_ALLOC); + if ((type_ == ALLOC_BUF && userPtr != DontSetProtectRWE && useProtect()) && !setProtectMode(PROTECT_RWE, false)) { + alloc_->free(top_); + throw Error(ERR_CANT_PROTECT); + } + } + virtual ~CodeArray() + { + if (isAllocType()) { + if (useProtect()) setProtectModeRW(false); + alloc_->free(top_); + } + } + bool setProtectMode(ProtectMode mode, bool throwException = true) + { + bool isOK = protect(top_, maxSize_, mode); + if (isOK) return true; + if (throwException) throw Error(ERR_CANT_PROTECT); + return false; + } + bool setProtectModeRE(bool throwException = true) { return setProtectMode(PROTECT_RE, throwException); } + bool setProtectModeRW(bool throwException = true) { return setProtectMode(PROTECT_RW, throwException); } + void resetSize() + { + size_ = 0; + addrInfoList_.clear(); + isCalledCalcJmpAddress_ = false; + } + void db(int code) + { + if (size_ >= maxSize_) { + if (type_ == AUTO_GROW) { + growMemory(); + } else { + throw Error(ERR_CODE_IS_TOO_BIG); + } + } + top_[size_++] = static_cast(code); + } + void db(const uint8 *code, size_t codeSize) + { + for (size_t i = 0; i < codeSize; i++) db(code[i]); + } + void db(uint64 code, size_t codeSize) + { + if (codeSize > 8) throw Error(ERR_BAD_PARAMETER); + for (size_t i = 0; i < codeSize; i++) db(static_cast(code >> (i * 8))); + } + void dw(uint32 code) { db(code, 2); } + void dd(uint32 code) { db(code, 4); } + void dq(uint64 code) { db(code, 8); } + const uint8 *getCode() const { return top_; } + template + const F getCode() const { return reinterpret_cast(top_); } + const uint8 *getCurr() const { return &top_[size_]; } + template + const F getCurr() const { return reinterpret_cast(&top_[size_]); } + size_t getSize() const { return size_; } + void setSize(size_t size) + { + if (size > maxSize_) throw Error(ERR_OFFSET_IS_TOO_BIG); + size_ = size; + } + void dump() const + { + const uint8 *p = getCode(); + size_t bufSize = getSize(); + size_t remain = bufSize; + for (int i = 0; i < 4; i++) { + size_t disp = 16; + if (remain < 16) { + disp = remain; + } + for (size_t j = 0; j < 16; j++) { + if (j < disp) { + printf("%02X", p[i * 16 + j]); + } + } + putchar('\n'); + remain -= disp; + if (remain == 0) { + break; + } + } + } + /* + @param offset [in] offset from top + @param disp [in] offset from the next of jmp + @param size [in] write size(1, 2, 4, 8) + */ + void rewrite(size_t offset, uint64 disp, size_t size) + { + assert(offset < maxSize_); + if (size != 1 && size != 2 && size != 4 && size != 8) throw Error(ERR_BAD_PARAMETER); + uint8 *const data = top_ + offset; + for (size_t i = 0; i < size; i++) { + data[i] = static_cast(disp >> (i * 8)); + } + } + void save(size_t offset, size_t val, int size, inner::LabelMode mode) + { + addrInfoList_.push_back(AddrInfo(offset, val, size, mode)); + } + bool isAutoGrow() const { return type_ == AUTO_GROW; } + bool isCalledCalcJmpAddress() const { return isCalledCalcJmpAddress_; } + /** + change exec permission of memory + @param addr [in] buffer address + @param size [in] buffer size + @param protectMode [in] mode(RW/RWE/RE) + @return true(success), false(failure) + */ + static inline bool protect(const void *addr, size_t size, int protectMode) + { +#if defined(_WIN32) + const DWORD c_rw = PAGE_READWRITE; + const DWORD c_rwe = PAGE_EXECUTE_READWRITE; + const DWORD c_re = PAGE_EXECUTE_READ; + DWORD mode; +#else + const int c_rw = PROT_READ | PROT_WRITE; + const int c_rwe = PROT_READ | PROT_WRITE | PROT_EXEC; + const int c_re = PROT_READ | PROT_EXEC; + int mode; +#endif + switch (protectMode) { + case PROTECT_RW: mode = c_rw; break; + case PROTECT_RWE: mode = c_rwe; break; + case PROTECT_RE: mode = c_re; break; + default: + return false; + } +#if defined(_WIN32) + DWORD oldProtect; + return VirtualProtect(const_cast(addr), size, mode, &oldProtect) != 0; +#elif defined(__GNUC__) + size_t pageSize = sysconf(_SC_PAGESIZE); + size_t iaddr = reinterpret_cast(addr); + size_t roundAddr = iaddr & ~(pageSize - static_cast(1)); +#ifndef NDEBUG + if (pageSize != 4096) fprintf(stderr, "large page(%zd) is used. not tested enough.\n", pageSize); +#endif + return mprotect(reinterpret_cast(roundAddr), size + (iaddr - roundAddr), mode) == 0; +#else + return true; +#endif + } + /** + get aligned memory pointer + @param addr [in] address + @param alignedSize [in] power of two + @return aligned addr by alingedSize + */ + static inline uint8 *getAlignedAddress(uint8 *addr, size_t alignedSize = 16) + { + return reinterpret_cast((reinterpret_cast(addr) + alignedSize - 1) & ~(alignedSize - static_cast(1))); + } +}; + +class Address : public Operand { +public: + enum Mode { + M_ModRM, + M_64bitDisp, + M_rip, + M_ripAddr + }; + Address(uint32 sizeBit, bool broadcast, const RegExp& e) + : Operand(0, MEM, sizeBit), e_(e), label_(0), mode_(M_ModRM), broadcast_(broadcast) + { + e_.verify(); + } +#ifdef XBYAK64 + explicit Address(size_t disp) + : Operand(0, MEM, 64), e_(disp), label_(0), mode_(M_64bitDisp), broadcast_(false){ } + Address(uint32 sizeBit, bool broadcast, const RegRip& addr) + : Operand(0, MEM, sizeBit), e_(addr.disp_), label_(addr.label_), mode_(addr.isAddr_ ? M_ripAddr : M_rip), broadcast_(broadcast) { } +#endif + RegExp getRegExp(bool optimize = true) const + { + return optimize ? e_.optimize() : e_; + } + Mode getMode() const { return mode_; } + bool is32bit() const { return e_.getBase().getBit() == 32 || e_.getIndex().getBit() == 32; } + bool isOnlyDisp() const { return !e_.getBase().getBit() && !e_.getIndex().getBit(); } // for mov eax + size_t getDisp() const { return e_.getDisp(); } + uint8 getRex() const + { + if (mode_ != M_ModRM) return 0; + return getRegExp().getRex(); + } + bool is64bitDisp() const { return mode_ == M_64bitDisp; } // for moffset + bool isBroadcast() const { return broadcast_; } + const Label* getLabel() const { return label_; } + bool operator==(const Address& rhs) const + { + return getBit() == rhs.getBit() && e_ == rhs.e_ && label_ == rhs.label_ && mode_ == rhs.mode_ && broadcast_ == rhs.broadcast_; + } + bool operator!=(const Address& rhs) const { return !operator==(rhs); } + bool isVsib() const { return e_.isVsib(); } +private: + RegExp e_; + const Label* label_; + Mode mode_; + bool broadcast_; +}; + +inline const Address& Operand::getAddress() const +{ + assert(isMEM()); + return static_cast(*this); +} + +inline bool Operand::operator==(const Operand& rhs) const +{ + if (isMEM() && rhs.isMEM()) return this->getAddress() == rhs.getAddress(); + return isEqualIfNotInherited(rhs); +} + +class AddressFrame { + void operator=(const AddressFrame&); + AddressFrame(const AddressFrame&); +public: + const uint32 bit_; + const bool broadcast_; + explicit AddressFrame(uint32 bit, bool broadcast = false) : bit_(bit), broadcast_(broadcast) { } + Address operator[](const RegExp& e) const + { + return Address(bit_, broadcast_, e); + } + Address operator[](const void *disp) const + { + return Address(bit_, broadcast_, RegExp(reinterpret_cast(disp))); + } +#ifdef XBYAK64 + Address operator[](uint64 disp) const { return Address(disp); } + Address operator[](const RegRip& addr) const { return Address(bit_, broadcast_, addr); } +#endif +}; + +struct JmpLabel { + size_t endOfJmp; /* offset from top to the end address of jmp */ + int jmpSize; + inner::LabelMode mode; + size_t disp; // disp for [rip + disp] + explicit JmpLabel(size_t endOfJmp = 0, int jmpSize = 0, inner::LabelMode mode = inner::LasIs, size_t disp = 0) + : endOfJmp(endOfJmp), jmpSize(jmpSize), mode(mode), disp(disp) + { + } +}; + +class LabelManager; + +class Label { + mutable LabelManager *mgr; + mutable int id; + friend class LabelManager; +public: + Label() : mgr(0), id(0) {} + Label(const Label& rhs); + Label& operator=(const Label& rhs); + ~Label(); + void clear() { mgr = 0; id = 0; } + int getId() const { return id; } + const uint8 *getAddress() const; + + // backward compatibility + static inline std::string toStr(int num) + { + char buf[16]; +#if defined(_MSC_VER) && (_MSC_VER < 1900) + _snprintf_s +#else + snprintf +#endif + (buf, sizeof(buf), ".%08x", num); + return buf; + } +}; + +class LabelManager { + // for string label + struct SlabelVal { + size_t offset; + SlabelVal(size_t offset) : offset(offset) {} + }; + typedef XBYAK_STD_UNORDERED_MAP SlabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP SlabelUndefList; + struct SlabelState { + SlabelDefList defList; + SlabelUndefList undefList; + }; + typedef std::list StateList; + // for Label class + struct ClabelVal { + ClabelVal(size_t offset = 0) : offset(offset), refCount(1) {} + size_t offset; + int refCount; + }; + typedef XBYAK_STD_UNORDERED_MAP ClabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP ClabelUndefList; + typedef XBYAK_STD_UNORDERED_SET LabelPtrList; + + CodeArray *base_; + // global : stateList_.front(), local : stateList_.back() + StateList stateList_; + mutable int labelId_; + ClabelDefList clabelDefList_; + ClabelUndefList clabelUndefList_; + LabelPtrList labelPtrList_; + + int getId(const Label& label) const + { + if (label.id == 0) label.id = labelId_++; + return label.id; + } + template + void define_inner(DefList& defList, UndefList& undefList, const T& labelId, size_t addrOffset) + { + // add label + typename DefList::value_type item(labelId, addrOffset); + std::pair ret = defList.insert(item); + if (!ret.second) throw Error(ERR_LABEL_IS_REDEFINED); + // search undefined label + for (;;) { + typename UndefList::iterator itr = undefList.find(labelId); + if (itr == undefList.end()) break; + const JmpLabel *jmp = &itr->second; + const size_t offset = jmp->endOfJmp - jmp->jmpSize; + size_t disp; + if (jmp->mode == inner::LaddTop) { + disp = addrOffset; + } else if (jmp->mode == inner::Labs) { + disp = size_t(base_->getCurr()); + } else { + disp = addrOffset - jmp->endOfJmp + jmp->disp; +#ifdef XBYAK64 + if (jmp->jmpSize <= 4 && !inner::IsInInt32(disp)) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + if (jmp->jmpSize == 1 && !inner::IsInDisp8((uint32)disp)) throw Error(ERR_LABEL_IS_TOO_FAR); + } + if (base_->isAutoGrow()) { + base_->save(offset, disp, jmp->jmpSize, jmp->mode); + } else { + base_->rewrite(offset, disp, jmp->jmpSize); + } + undefList.erase(itr); + } + } + template + bool getOffset_inner(const DefList& defList, size_t *offset, const T& label) const + { + typename DefList::const_iterator i = defList.find(label); + if (i == defList.end()) return false; + *offset = i->second.offset; + return true; + } + friend class Label; + void incRefCount(int id, Label *label) + { + clabelDefList_[id].refCount++; + labelPtrList_.insert(label); + } + void decRefCount(int id, Label *label) + { + labelPtrList_.erase(label); + ClabelDefList::iterator i = clabelDefList_.find(id); + if (i == clabelDefList_.end()) return; + if (i->second.refCount == 1) { + clabelDefList_.erase(id); + } else { + --i->second.refCount; + } + } + template + bool hasUndefinedLabel_inner(const T& list) const + { +#ifndef NDEBUG + for (typename T::const_iterator i = list.begin(); i != list.end(); ++i) { + std::cerr << "undefined label:" << i->first << std::endl; + } +#endif + return !list.empty(); + } + // detach all labels linked to LabelManager + void resetLabelPtrList() + { + for (LabelPtrList::iterator i = labelPtrList_.begin(), ie = labelPtrList_.end(); i != ie; ++i) { + (*i)->clear(); + } + labelPtrList_.clear(); + } +public: + LabelManager() + { + reset(); + } + ~LabelManager() + { + resetLabelPtrList(); + } + void reset() + { + base_ = 0; + labelId_ = 1; + stateList_.clear(); + stateList_.push_back(SlabelState()); + stateList_.push_back(SlabelState()); + clabelDefList_.clear(); + clabelUndefList_.clear(); + resetLabelPtrList(); + } + void enterLocal() + { + stateList_.push_back(SlabelState()); + } + void leaveLocal() + { + if (stateList_.size() <= 2) throw Error(ERR_UNDER_LOCAL_LABEL); + if (hasUndefinedLabel_inner(stateList_.back().undefList)) throw Error(ERR_LABEL_IS_NOT_FOUND); + stateList_.pop_back(); + } + void set(CodeArray *base) { base_ = base; } + void defineSlabel(std::string label) + { + if (label == "@b" || label == "@f") throw Error(ERR_BAD_LABEL_STR); + if (label == "@@") { + SlabelDefList& defList = stateList_.front().defList; + SlabelDefList::iterator i = defList.find("@f"); + if (i != defList.end()) { + defList.erase(i); + label = "@b"; + } else { + i = defList.find("@b"); + if (i != defList.end()) { + defList.erase(i); + } + label = "@f"; + } + } + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + define_inner(st.defList, st.undefList, label, base_->getSize()); + } + void defineClabel(Label& label) + { + define_inner(clabelDefList_, clabelUndefList_, getId(label), base_->getSize()); + label.mgr = this; + labelPtrList_.insert(&label); + } + void assign(Label& dst, const Label& src) + { + ClabelDefList::const_iterator i = clabelDefList_.find(src.id); + if (i == clabelDefList_.end()) throw Error(ERR_LABEL_ISNOT_SET_BY_L); + define_inner(clabelDefList_, clabelUndefList_, dst.id, i->second.offset); + dst.mgr = this; + labelPtrList_.insert(&dst); + } + bool getOffset(size_t *offset, std::string& label) const + { + const SlabelDefList& defList = stateList_.front().defList; + if (label == "@b") { + if (defList.find("@f") != defList.end()) { + label = "@f"; + } else if (defList.find("@b") == defList.end()) { + throw Error(ERR_LABEL_IS_NOT_FOUND); + } + } else if (label == "@f") { + if (defList.find("@f") != defList.end()) { + label = "@b"; + } + } + const SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + return getOffset_inner(st.defList, offset, label); + } + bool getOffset(size_t *offset, const Label& label) const + { + return getOffset_inner(clabelDefList_, offset, getId(label)); + } + void addUndefinedLabel(const std::string& label, const JmpLabel& jmp) + { + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + st.undefList.insert(SlabelUndefList::value_type(label, jmp)); + } + void addUndefinedLabel(const Label& label, const JmpLabel& jmp) + { + clabelUndefList_.insert(ClabelUndefList::value_type(label.id, jmp)); + } + bool hasUndefSlabel() const + { + for (StateList::const_iterator i = stateList_.begin(), ie = stateList_.end(); i != ie; ++i) { + if (hasUndefinedLabel_inner(i->undefList)) return true; + } + return false; + } + bool hasUndefClabel() const { return hasUndefinedLabel_inner(clabelUndefList_); } + const uint8 *getCode() const { return base_->getCode(); } + bool isReady() const { return !base_->isAutoGrow() || base_->isCalledCalcJmpAddress(); } +}; + +inline Label::Label(const Label& rhs) +{ + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); +} +inline Label& Label::operator=(const Label& rhs) +{ + if (id) throw Error(ERR_LABEL_IS_ALREADY_SET_BY_L); + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); + return *this; +} +inline Label::~Label() +{ + if (id && mgr) mgr->decRefCount(id, this); +} +inline const uint8* Label::getAddress() const +{ + if (mgr == 0 || !mgr->isReady()) return 0; + size_t offset; + if (!mgr->getOffset(&offset, *this)) return 0; + return mgr->getCode() + offset; +} + +class CodeGenerator : public CodeArray { +public: + enum LabelType { + T_SHORT, + T_NEAR, + T_AUTO // T_SHORT if possible + }; +private: + CodeGenerator operator=(const CodeGenerator&); // don't call +#ifdef XBYAK64 + enum { i32e = 32 | 64, BIT = 64 }; + static const size_t dummyAddr = (size_t(0x11223344) << 32) | 55667788; + typedef Reg64 NativeReg; +#else + enum { i32e = 32, BIT = 32 }; + static const size_t dummyAddr = 0x12345678; + typedef Reg32 NativeReg; +#endif + // (XMM, XMM|MEM) + static inline bool isXMM_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isXMM() || op2.isMEM()); + } + // (MMX, MMX|MEM) or (XMM, XMM|MEM) + static inline bool isXMMorMMX_MEM(const Operand& op1, const Operand& op2) + { + return (op1.isMMX() && (op2.isMMX() || op2.isMEM())) || isXMM_XMMorMEM(op1, op2); + } + // (XMM, MMX|MEM) + static inline bool isXMM_MMXorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isMMX() || op2.isMEM()); + } + // (MMX, XMM|MEM) + static inline bool isMMX_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isMMX() && (op2.isXMM() || op2.isMEM()); + } + // (XMM, REG32|MEM) + static inline bool isXMM_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isREG(i32e) || op2.isMEM()); + } + // (REG32, XMM|MEM) + static inline bool isREG32_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && (op2.isXMM() || op2.isMEM()); + } + // (REG32, REG32|MEM) + static inline bool isREG32_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && ((op2.isREG(i32e) && op1.getBit() == op2.getBit()) || op2.isMEM()); + } + void rex(const Operand& op1, const Operand& op2 = Operand()) + { + uint8 rex = 0; + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM()) std::swap(p1, p2); + if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); + if (p2->isMEM()) { + const Address& addr = p2->getAddress(); + if (BIT == 64 && addr.is32bit()) db(0x67); + rex = addr.getRex() | p1->getReg().getRex(); + } else { + // ModRM(reg, base); + rex = op2.getReg().getRex(op1.getReg()); + } + // except movsx(16bit, 32/64bit) + if ((op1.isBit(16) && !op2.isBit(i32e)) || (op2.isBit(16) && !op1.isBit(i32e))) db(0x66); + if (rex) db(rex); + } + enum AVXtype { + // low 3 bit + T_N1 = 1, + T_N2 = 2, + T_N4 = 3, + T_N8 = 4, + T_N16 = 5, + T_N32 = 6, + T_NX_MASK = 7, + // + T_N_VL = 1 << 3, // N * (1, 2, 4) for VL + T_DUP = 1 << 4, // N = (8, 32, 64) + T_66 = 1 << 5, + T_F3 = 1 << 6, + T_F2 = 1 << 7, + T_0F = 1 << 8, + T_0F38 = 1 << 9, + T_0F3A = 1 << 10, + T_L0 = 1 << 11, + T_L1 = 1 << 12, + T_W0 = 1 << 13, + T_W1 = 1 << 14, + T_EW0 = 1 << 15, + T_EW1 = 1 << 16, + T_YMM = 1 << 17, // support YMM, ZMM + T_EVEX = 1 << 18, + T_ER_X = 1 << 19, // xmm{er} + T_ER_Y = 1 << 20, // ymm{er} + T_ER_Z = 1 << 21, // zmm{er} + T_SAE_X = 1 << 22, // xmm{sae} + T_SAE_Y = 1 << 23, // ymm{sae} + T_SAE_Z = 1 << 24, // zmm{sae} + T_MUST_EVEX = 1 << 25, // contains T_EVEX + T_B32 = 1 << 26, // m32bcst + T_B64 = 1 << 27, // m64bcst + T_M_K = 1 << 28, // mem{k} + T_VSIB = 1 << 29, + T_MEM_EVEX = 1 << 30, // use evex if mem + T_XXX + }; + void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false) + { + int w = (type & T_W1) ? 1 : 0; + bool is256 = (type & T_L1) ? true : (type & T_L0) ? false : reg.isYMM(); + bool r = reg.isExtIdx(); + bool b = base.isExtIdx(); + int idx = v ? v->getIdx() : 0; + if ((idx | reg.getIdx() | base.getIdx()) >= 16) throw Error(ERR_BAD_COMBINATION); + uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; + uint32 vvvv = (((~idx) & 15) << 3) | (is256 ? 4 : 0) | pp; + if (!b && !x && !w && (type & T_0F)) { + db(0xC5); db((r ? 0 : 0x80) | vvvv); + } else { + uint32 mmmm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; + db(0xC4); db((r ? 0 : 0x80) | (x ? 0 : 0x40) | (b ? 0 : 0x20) | mmmm); db((w << 7) | vvvv); + } + db(code); + } + void verifySAE(const Reg& r, int type) const + { + if (((type & T_SAE_X) && r.isXMM()) || ((type & T_SAE_Y) && r.isYMM()) || ((type & T_SAE_Z) && r.isZMM())) return; + throw Error(ERR_SAE_IS_INVALID); + } + void verifyER(const Reg& r, int type) const + { + if (((type & T_ER_X) && r.isXMM()) || ((type & T_ER_Y) && r.isYMM()) || ((type & T_ER_Z) && r.isZMM())) return; + throw Error(ERR_ER_IS_INVALID); + } + // (a, b, c) contains non zero two or three values then err + int verifyDuplicate(int a, int b, int c, int err) + { + int v = a | b | c; + if ((a > 0 && a != v) + (b > 0 && b != v) + (c > 0 && c != v) > 0) return Error(err); + return v; + } + int evex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false, bool b = false, int aaa = 0, uint32 VL = 0, bool Hi16Vidx = false) + { + if (!(type & (T_EVEX | T_MUST_EVEX))) throw Error(ERR_EVEX_IS_INVALID); + int w = (type & T_EW1) ? 1 : 0; + uint32 mm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; + uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; + + int idx = v ? v->getIdx() : 0; + uint32 vvvv = ~idx; + + bool R = !reg.isExtIdx(); + bool X = x ? false : !base.isExtIdx2(); + bool B = !base.isExtIdx(); + bool Rp = !reg.isExtIdx2(); + int LL; + int rounding = verifyDuplicate(reg.getRounding(), base.getRounding(), v ? v->getRounding() : 0, ERR_ROUNDING_IS_ALREADY_SET); + int disp8N = 1; + if (rounding) { + if (rounding == EvexModifierRounding::T_SAE) { + verifySAE(base, type); LL = 0; + } else { + verifyER(base, type); LL = rounding - 1; + } + b = true; + } else { + if (v) VL = (std::max)(VL, v->getBit()); + VL = (std::max)((std::max)(reg.getBit(), base.getBit()), VL); + LL = (VL == 512) ? 2 : (VL == 256) ? 1 : 0; + if (b) { + disp8N = (type & T_B32) ? 4 : 8; + } else if (type & T_DUP) { + disp8N = VL == 128 ? 8 : VL == 256 ? 32 : 64; + } else { + if ((type & (T_NX_MASK | T_N_VL)) == 0) { + type |= T_N16 | T_N_VL; // default + } + int low = type & T_NX_MASK; + if (low > 0) { + disp8N = 1 << (low - 1); + if (type & T_N_VL) disp8N *= (VL == 512 ? 4 : VL == 256 ? 2 : 1); + } + } + } + bool Vp = !((v ? v->isExtIdx2() : 0) | Hi16Vidx); + bool z = reg.hasZero() || base.hasZero() || (v ? v->hasZero() : false); + if (aaa == 0) aaa = verifyDuplicate(base.getOpmaskIdx(), reg.getOpmaskIdx(), (v ? v->getOpmaskIdx() : 0), ERR_OPMASK_IS_ALREADY_SET); + db(0x62); + db((R ? 0x80 : 0) | (X ? 0x40 : 0) | (B ? 0x20 : 0) | (Rp ? 0x10 : 0) | (mm & 3)); + db((w == 1 ? 0x80 : 0) | ((vvvv & 15) << 3) | 4 | (pp & 3)); + db((z ? 0x80 : 0) | ((LL & 3) << 5) | (b ? 0x10 : 0) | (Vp ? 8 : 0) | (aaa & 7)); + db(code); + return disp8N; + } + void setModRM(int mod, int r1, int r2) + { + db(static_cast((mod << 6) | ((r1 & 7) << 3) | (r2 & 7))); + } + void setSIB(const RegExp& e, int reg, int disp8N = 0) + { + size_t disp64 = e.getDisp(); +#ifdef XBYAK64 + size_t high = disp64 >> 32; + if (high != 0 && high != 0xFFFFFFFF) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + uint32 disp = static_cast(disp64); + const Reg& base = e.getBase(); + const Reg& index = e.getIndex(); + const int baseIdx = base.getIdx(); + const int baseBit = base.getBit(); + const int indexBit = index.getBit(); + enum { + mod00 = 0, mod01 = 1, mod10 = 2 + }; + int mod = mod10; // disp32 + if (!baseBit || ((baseIdx & 7) != Operand::EBP && disp == 0)) { + mod = mod00; + } else { + if (disp8N == 0) { + if (inner::IsInDisp8(disp)) { + mod = mod01; + } + } else { + // disp must be casted to signed + uint32 t = static_cast(static_cast(disp) / disp8N); + if ((disp % disp8N) == 0 && inner::IsInDisp8(t)) { + disp = t; + mod = mod01; + } + } + } + const int newBaseIdx = baseBit ? (baseIdx & 7) : Operand::EBP; + /* ModR/M = [2:3:3] = [Mod:reg/code:R/M] */ + bool hasSIB = indexBit || (baseIdx & 7) == Operand::ESP; +#ifdef XBYAK64 + if (!baseBit && !indexBit) hasSIB = true; +#endif + if (hasSIB) { + setModRM(mod, reg, Operand::ESP); + /* SIB = [2:3:3] = [SS:index:base(=rm)] */ + const int idx = indexBit ? (index.getIdx() & 7) : Operand::ESP; + const int scale = e.getScale(); + const int SS = (scale == 8) ? 3 : (scale == 4) ? 2 : (scale == 2) ? 1 : 0; + setModRM(SS, idx, newBaseIdx); + } else { + setModRM(mod, reg, newBaseIdx); + } + if (mod == mod01) { + db(disp); + } else if (mod == mod10 || (mod == mod00 && !baseBit)) { + dd(disp); + } + } + LabelManager labelMgr_; + bool isInDisp16(uint32 x) const { return 0xFFFF8000 <= x || x <= 0x7FFF; } + void opModR(const Reg& reg1, const Reg& reg2, int code0, int code1 = NONE, int code2 = NONE) + { + rex(reg2, reg1); + db(code0 | (reg1.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); + setModRM(3, reg1.getIdx(), reg2.getIdx()); + } + void opModM(const Address& addr, const Reg& reg, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + rex(addr, reg); + db(code0 | (reg.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); + opAddr(addr, reg.getIdx(), immSize); + } + void opMIB(const Address& addr, const Reg& reg, int code0, int code1) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + if (addr.getMode() != Address::M_ModRM) throw Error(ERR_INVALID_MIB_ADDRESS); + if (BIT == 64 && addr.is32bit()) db(0x67); + const RegExp& regExp = addr.getRegExp(false); + uint8 rex = regExp.getRex(); + if (rex) db(rex); + db(code0); db(code1); + setSIB(regExp, reg.getIdx()); + } + void makeJmp(uint32 disp, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) + { + const int shortJmpSize = 2; + const int longHeaderSize = longPref ? 2 : 1; + const int longJmpSize = longHeaderSize + 4; + if (type != T_NEAR && inner::IsInDisp8(disp - shortJmpSize)) { + db(shortCode); db(disp - shortJmpSize); + } else { + if (type == T_SHORT) throw Error(ERR_LABEL_IS_TOO_FAR); + if (longPref) db(longPref); + db(longCode); dd(disp - longJmpSize); + } + } + template + void opJmp(T& label, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) + { + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); /* avoid splitting code of jmp */ + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { /* label exists */ + makeJmp(inner::VerifyInInt32(offset - size_), type, shortCode, longCode, longPref); + } else { + int jmpSize = 0; + if (type == T_NEAR) { + jmpSize = 4; + if (longPref) db(longPref); + db(longCode); dd(0); + } else { + jmpSize = 1; + db(shortCode); db(0); + } + JmpLabel jmp(size_, jmpSize, inner::LasIs); + labelMgr_.addUndefinedLabel(label, jmp); + } + } + void opJmpAbs(const void *addr, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref = 0) + { + if (isAutoGrow()) { + if (type != T_NEAR) throw Error(ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW); + if (size_ + 16 >= maxSize_) growMemory(); + if (longPref) db(longPref); + db(longCode); + dd(0); + save(size_ - 4, size_t(addr) - size_, 4, inner::Labs); + } else { + makeJmp(inner::VerifyInInt32(reinterpret_cast(addr) - getCurr()), type, shortCode, longCode, longPref); + } + + } + // reg is reg field of ModRM + // immSize is the size for immediate value + // disp8N = 0(normal), disp8N = 1(force disp32), disp8N = {2, 4, 8} ; compressed displacement + void opAddr(const Address &addr, int reg, int immSize = 0, int disp8N = 0, bool permitVisb = false) + { + if (!permitVisb && addr.isVsib()) throw Error(ERR_BAD_VSIB_ADDRESSING); + if (addr.getMode() == Address::M_ModRM) { + setSIB(addr.getRegExp(), reg, disp8N); + } else if (addr.getMode() == Address::M_rip || addr.getMode() == Address::M_ripAddr) { + setModRM(0, reg, 5); + if (addr.getLabel()) { // [rip + Label] + putL_inner(*addr.getLabel(), true, addr.getDisp() - immSize); + } else { + size_t disp = addr.getDisp(); + if (addr.getMode() == Address::M_ripAddr) { + if (isAutoGrow()) throw Error(ERR_INVALID_RIP_IN_AUTO_GROW); + disp -= (size_t)getCurr() + 4 + immSize; + } + dd(inner::VerifyInInt32(disp)); + } + } + } + /* preCode is for SSSE3/SSE4 */ + void opGen(const Operand& reg, const Operand& op, int code, int pref, bool isValid(const Operand&, const Operand&), int imm8 = NONE, int preCode = NONE) + { + if (isValid && !isValid(reg, op)) throw Error(ERR_BAD_COMBINATION); + if (pref != NONE) db(pref); + if (op.isMEM()) { + opModM(op.getAddress(), reg.getReg(), 0x0F, preCode, code, (imm8 != NONE) ? 1 : 0); + } else { + opModR(reg.getReg(), op.getReg(), 0x0F, preCode, code); + } + if (imm8 != NONE) db(imm8); + } + void opMMX_IMM(const Mmx& mmx, int imm8, int code, int ext) + { + if (mmx.isXMM()) db(0x66); + opModR(Reg32(ext), mmx, 0x0F, code); + db(imm8); + } + void opMMX(const Mmx& mmx, const Operand& op, int code, int pref = 0x66, int imm8 = NONE, int preCode = NONE) + { + opGen(mmx, op, code, mmx.isXMM() ? pref : NONE, isXMMorMMX_MEM, imm8, preCode); + } + void opMovXMM(const Operand& op1, const Operand& op2, int code, int pref) + { + if (pref != NONE) db(pref); + if (op1.isXMM() && op2.isMEM()) { + opModM(op2.getAddress(), op1.getReg(), 0x0F, code); + } else if (op1.isMEM() && op2.isXMM()) { + opModM(op1.getAddress(), op2.getReg(), 0x0F, code | 1); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opExt(const Operand& op, const Mmx& mmx, int code, int imm, bool hasMMX2 = false) + { + if (hasMMX2 && op.isREG(i32e)) { /* pextrw is special */ + if (mmx.isXMM()) db(0x66); + opModR(op.getReg(), mmx, 0x0F, 0xC5); db(imm); + } else { + opGen(mmx, op, code, 0x66, isXMM_REG32orMEM, imm, 0x3A); + } + } + void opR_ModM(const Operand& op, int bit, int ext, int code0, int code1 = NONE, int code2 = NONE, bool disableRex = false, int immSize = 0) + { + int opBit = op.getBit(); + if (disableRex && opBit == 64) opBit = 32; + if (op.isREG(bit)) { + opModR(Reg(ext, Operand::REG, opBit), op.getReg().changeBit(opBit), code0, code1, code2); + } else if (op.isMEM()) { + opModM(op.getAddress(), Reg(ext, Operand::REG, opBit), code0, code1, code2, immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opShift(const Operand& op, int imm, int ext) + { + verifyMemHasSize(op); + opR_ModM(op, 0, ext, (0xC0 | ((imm == 1 ? 1 : 0) << 4)), NONE, NONE, false, (imm != 1) ? 1 : 0); + if (imm != 1) db(imm); + } + void opShift(const Operand& op, const Reg8& _cl, int ext) + { + if (_cl.getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); + opR_ModM(op, 0, ext, 0xD2); + } + void opModRM(const Operand& op1, const Operand& op2, bool condR, bool condM, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) + { + if (condR) { + opModR(op1.getReg(), op2.getReg(), code0, code1, code2); + } else if (condM) { + opModM(op2.getAddress(), op1.getReg(), code0, code1, code2, immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opShxd(const Operand& op, const Reg& reg, uint8 imm, int code, const Reg8 *_cl = 0) + { + if (_cl && _cl->getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); + opModRM(reg, op, (op.isREG(16 | i32e) && op.getBit() == reg.getBit()), op.isMEM() && (reg.isREG(16 | i32e)), 0x0F, code | (_cl ? 1 : 0), NONE, _cl ? 0 : 1); + if (!_cl) db(imm); + } + // (REG, REG|MEM), (MEM, REG) + void opRM_RM(const Operand& op1, const Operand& op2, int code) + { + if (op1.isREG() && op2.isMEM()) { + opModM(op2.getAddress(), op1.getReg(), code | 2); + } else { + opModRM(op2, op1, op1.isREG() && op1.getKind() == op2.getKind(), op1.isMEM() && op2.isREG(), code); + } + } + // (REG|MEM, IMM) + void opRM_I(const Operand& op, uint32 imm, int code, int ext) + { + verifyMemHasSize(op); + uint32 immBit = inner::IsInDisp8(imm) ? 8 : isInDisp16(imm) ? 16 : 32; + if (op.isBit(8)) immBit = 8; + if (op.getBit() < immBit) throw Error(ERR_IMM_IS_TOO_BIG); + if (op.isBit(32|64) && immBit == 16) immBit = 32; /* don't use MEM16 if 32/64bit mode */ + if (op.isREG() && op.getIdx() == 0 && (op.getBit() == immBit || (op.isBit(64) && immBit == 32))) { // rax, eax, ax, al + rex(op); + db(code | 4 | (immBit == 8 ? 0 : 1)); + } else { + int tmp = immBit < (std::min)(op.getBit(), 32U) ? 2 : 0; + opR_ModM(op, 0, ext, 0x80 | tmp, NONE, NONE, false, immBit / 8); + } + db(imm, immBit / 8); + } + void opIncDec(const Operand& op, int code, int ext) + { + verifyMemHasSize(op); +#ifndef XBYAK64 + if (op.isREG() && !op.isBit(8)) { + rex(op); db(code | op.getIdx()); + return; + } +#endif + code = 0xFE; + if (op.isREG()) { + opModR(Reg(ext, Operand::REG, op.getBit()), op.getReg(), code); + } else { + opModM(op.getAddress(), Reg(ext, Operand::REG, op.getBit()), code); + } + } + void opPushPop(const Operand& op, int code, int ext, int alt) + { + int bit = op.getBit(); + if (bit == 16 || bit == BIT) { + if (bit == 16) db(0x66); + if (op.isREG()) { + if (op.getReg().getIdx() >= 8) db(0x41); + db(alt | (op.getIdx() & 7)); + return; + } + if (op.isMEM()) { + opModM(op.getAddress(), Reg(ext, Operand::REG, 32), code); + return; + } + } + throw Error(ERR_BAD_COMBINATION); + } + void verifyMemHasSize(const Operand& op) const + { + if (op.isMEM() && op.getBit() == 0) throw Error(ERR_MEM_SIZE_IS_NOT_SPECIFIED); + } + /* + mov(r, imm) = db(imm, mov_imm(r, imm)) + */ + int mov_imm(const Reg& reg, size_t imm) + { + int bit = reg.getBit(); + const int idx = reg.getIdx(); + int code = 0xB0 | ((bit == 8 ? 0 : 1) << 3); + if (bit == 64 && (imm & ~size_t(0xffffffffu)) == 0) { + rex(Reg32(idx)); + bit = 32; + } else { + rex(reg); + if (bit == 64 && inner::IsInInt32(imm)) { + db(0xC7); + code = 0xC0; + bit = 32; + } + } + db(code | (idx & 7)); + return bit / 8; + } + template + void putL_inner(T& label, bool relative = false, size_t disp = 0) + { + const int jmpSize = relative ? 4 : (int)sizeof(size_t); + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { + if (relative) { + db(inner::VerifyInInt32(offset + disp - size_ - jmpSize), jmpSize); + } else if (isAutoGrow()) { + db(uint64(0), jmpSize); + save(size_ - jmpSize, offset, jmpSize, inner::LaddTop); + } else { + db(size_t(top_) + offset, jmpSize); + } + return; + } + db(uint64(0), jmpSize); + JmpLabel jmp(size_, jmpSize, (relative ? inner::LasIs : isAutoGrow() ? inner::LaddTop : inner::Labs), disp); + labelMgr_.addUndefinedLabel(label, jmp); + } + void opMovxx(const Reg& reg, const Operand& op, uint8 code) + { + if (op.isBit(32)) throw Error(ERR_BAD_COMBINATION); + int w = op.isBit(16); +#ifdef XBYAK64 + if (op.isHigh8bit()) throw Error(ERR_BAD_COMBINATION); +#endif + bool cond = reg.isREG() && (reg.getBit() > op.getBit()); + opModRM(reg, op, cond && op.isREG(), cond && op.isMEM(), 0x0F, code | w); + } + void opFpuMem(const Address& addr, uint8 m16, uint8 m32, uint8 m64, uint8 ext, uint8 m64ext) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + uint8 code = addr.isBit(16) ? m16 : addr.isBit(32) ? m32 : addr.isBit(64) ? m64 : 0; + if (!code) throw Error(ERR_BAD_MEM_SIZE); + if (m64ext && addr.isBit(64)) ext = m64ext; + + rex(addr, st0); + db(code); + opAddr(addr, ext); + } + // use code1 if reg1 == st0 + // use code2 if reg1 != st0 && reg2 == st0 + void opFpuFpu(const Fpu& reg1, const Fpu& reg2, uint32 code1, uint32 code2) + { + uint32 code = reg1.getIdx() == 0 ? code1 : reg2.getIdx() == 0 ? code2 : 0; + if (!code) throw Error(ERR_BAD_ST_COMBINATION); + db(uint8(code >> 8)); + db(uint8(code | (reg1.getIdx() | reg2.getIdx()))); + } + void opFpu(const Fpu& reg, uint8 code1, uint8 code2) + { + db(code1); db(code2 | reg.getIdx()); + } + void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE) + { + if (op2.isMEM()) { + const Address& addr = op2.getAddress(); + const RegExp& regExp = addr.getRegExp(); + const Reg& base = regExp.getBase(); + const Reg& index = regExp.getIndex(); + if (BIT == 64 && addr.is32bit()) db(0x67); + int disp8N = 0; + bool x = index.isExtIdx(); + if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx()) { + int aaa = addr.getOpmaskIdx(); + if (aaa && !(type & T_M_K)) throw Error(ERR_INVALID_OPMASK_WITH_MEMORY); + bool b = false; + if (addr.isBroadcast()) { + if (!(type & (T_B32 | T_B64))) throw Error(ERR_INVALID_BROADCAST); + b = true; + } + int VL = regExp.isVsib() ? index.getBit() : 0; + disp8N = evex(r, base, p1, type, code, x, b, aaa, VL, index.isExtIdx2()); + } else { + vex(r, base, p1, type, code, x); + } + opAddr(addr, r.getIdx(), (imm8 != NONE) ? 1 : 0, disp8N, (type & T_VSIB) != 0); + } else { + const Reg& base = op2.getReg(); + if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) { + evex(r, base, p1, type, code); + } else { + vex(r, base, p1, type, code); + } + setModRM(3, r.getIdx(), base.getIdx()); + } + if (imm8 != NONE) db(imm8); + } + // (r, r, r/m) if isR_R_RM + // (r, r/m, r) + void opGpr(const Reg32e& r, const Operand& op1, const Operand& op2, int type, uint8 code, bool isR_R_RM, int imm8 = NONE) + { + const Operand *p1 = &op1; + const Operand *p2 = &op2; + if (!isR_R_RM) std::swap(p1, p2); + const unsigned int bit = r.getBit(); + if (p1->getBit() != bit || (p2->isREG() && p2->getBit() != bit)) throw Error(ERR_BAD_COMBINATION); + type |= (bit == 64) ? T_W1 : T_W0; + opVex(r, p1, *p2, type, code, imm8); + } + void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE) + { + const Xmm *x2 = static_cast(&op1); + const Operand *op = &op2; + if (op2.isNone()) { // (x1, op1) -> (x1, x1, op1) + x2 = &x1; + op = &op1; + } + // (x1, x2, op) + if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) throw Error(ERR_BAD_COMBINATION); + opVex(x1, x2, *op, type, code0, imm8); + } + void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, int type, int code0, int imm8 = NONE) + { + if (!op3.isMEM() && (x2.getKind() != op3.getKind())) throw Error(ERR_BAD_COMBINATION); + opVex(k, &x2, op3, type, code0, imm8); + } + // (x, x/m), (y, x/m256), (z, y/m) + void checkCvt1(const Operand& x, const Operand& op) const + { + if (!op.isMEM() && !(x.is(Operand::XMM | Operand::YMM) && op.isXMM()) && !(x.isZMM() && op.isYMM())) throw Error(ERR_BAD_COMBINATION); + } + // (x, x/m), (x, y/m256), (y, z/m) + void checkCvt2(const Xmm& x, const Operand& op) const + { + if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) throw Error(ERR_BAD_COMBINATION); + } + void opCvt2(const Xmm& x, const Operand& op, int type, int code) + { + checkCvt2(x, op); + Operand::Kind kind = x.isXMM() ? (op.isBit(256) ? Operand::YMM : Operand::XMM) : Operand::ZMM; + opVex(x.copyAndSetKind(kind), &xm0, op, type, code); + } + void opCvt3(const Xmm& x1, const Xmm& x2, const Operand& op, int type, int type64, int type32, uint8 code) + { + if (!(x1.isXMM() && x2.isXMM() && (op.isREG(i32e) || op.isMEM()))) throw Error(ERR_BAD_SIZE_OF_REGISTER); + Xmm x(op.getIdx()); + const Operand *p = op.isREG() ? &x : &op; + opVex(x1, &x2, *p, type | (op.isBit(64) ? type64 : type32), code); + } + const Xmm& cvtIdx0(const Operand& x) const + { + return x.isZMM() ? zm0 : x.isYMM() ? ym0 : xm0; + } + // support (x, x/m, imm), (y, y/m, imm) + void opAVX_X_XM_IMM(const Xmm& x, const Operand& op, int type, int code, int imm8 = NONE) + { + opAVX_X_X_XM(x, cvtIdx0(x), op, type, code, imm8); + } + // QQQ:need to refactor + void opSp1(const Reg& reg, const Operand& op, uint8 pref, uint8 code0, uint8 code1) + { + if (reg.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); + bool is16bit = reg.isREG(16) && (op.isREG(16) || op.isMEM()); + if (!is16bit && !(reg.isREG(i32e) && (op.isREG(reg.getBit()) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); + if (is16bit) db(0x66); + db(pref); opModRM(reg.changeBit(i32e == 32 ? 32 : reg.getBit()), op, op.isREG(), true, code0, code1); + } + void opGather(const Xmm& x1, const Address& addr, const Xmm& x2, int type, uint8 code, int mode) + { + const RegExp& regExp = addr.getRegExp(); + if (!regExp.isVsib(128 | 256)) throw Error(ERR_BAD_VSIB_ADDRESSING); + const int y_vx_y = 0; + const int y_vy_y = 1; +// const int x_vy_x = 2; + const bool isAddrYMM = regExp.getIndex().getBit() == 256; + if (!x1.isXMM() || isAddrYMM || !x2.isXMM()) { + bool isOK = false; + if (mode == y_vx_y) { + isOK = x1.isYMM() && !isAddrYMM && x2.isYMM(); + } else if (mode == y_vy_y) { + isOK = x1.isYMM() && isAddrYMM && x2.isYMM(); + } else { // x_vy_x + isOK = !x1.isYMM() && isAddrYMM && !x2.isYMM(); + } + if (!isOK) throw Error(ERR_BAD_VSIB_ADDRESSING); + } + opAVX_X_X_XM(isAddrYMM ? Ymm(x1.getIdx()) : x1, isAddrYMM ? Ymm(x2.getIdx()) : x2, addr, type, code); + } + enum { + xx_yy_zz = 0, + xx_yx_zy = 1, + xx_xy_yz = 2 + }; + void checkGather2(const Xmm& x1, const Reg& x2, int mode) const + { + if (x1.isXMM() && x2.isXMM()) return; + switch (mode) { + case xx_yy_zz: if ((x1.isYMM() && x2.isYMM()) || (x1.isZMM() && x2.isZMM())) return; + break; + case xx_yx_zy: if ((x1.isYMM() && x2.isXMM()) || (x1.isZMM() && x2.isYMM())) return; + break; + case xx_xy_yz: if ((x1.isXMM() && x2.isYMM()) || (x1.isYMM() && x2.isZMM())) return; + break; + } + throw Error(ERR_BAD_VSIB_ADDRESSING); + } + void opGather2(const Xmm& x, const Address& addr, int type, uint8 code, int mode) + { + if (x.hasZero()) throw Error(ERR_INVALID_ZERO); + checkGather2(x, addr.getRegExp().getIndex(), mode); + opVex(x, 0, addr, type, code); + } + /* + xx_xy_yz ; mode = true + xx_xy_xz ; mode = false + */ + void opVmov(const Operand& op, const Xmm& x, int type, uint8 code, bool mode) + { + if (mode) { + if (!op.isMEM() && !((op.isXMM() && x.isXMM()) || (op.isXMM() && x.isYMM()) || (op.isYMM() && x.isZMM()))) throw Error(ERR_BAD_COMBINATION); + } else { + if (!op.isMEM() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); + } + opVex(x, 0, op, type, code); + } + void opGatherFetch(const Address& addr, const Xmm& x, int type, uint8 code, Operand::Kind kind) + { + if (addr.hasZero()) throw Error(ERR_INVALID_ZERO); + if (addr.getRegExp().getIndex().getKind() != kind) throw Error(ERR_BAD_VSIB_ADDRESSING); + opVex(x, 0, addr, type, code); + } +public: + unsigned int getVersion() const { return VERSION; } + using CodeArray::db; + const Mmx mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7; + const Xmm xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7; + const Ymm ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7; + const Zmm zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + const Xmm &xm0, &xm1, &xm2, &xm3, &xm4, &xm5, &xm6, &xm7; + const Ymm &ym0, &ym1, &ym2, &ym3, &ym4, &ym5, &ym6, &ym7; + const Ymm &zm0, &zm1, &zm2, &zm3, &zm4, &zm5, &zm6, &zm7; + const Reg32 eax, ecx, edx, ebx, esp, ebp, esi, edi; + const Reg16 ax, cx, dx, bx, sp, bp, si, di; + const Reg8 al, cl, dl, bl, ah, ch, dh, bh; + const AddressFrame ptr, byte, word, dword, qword, xword, yword, zword; // xword is same as oword of NASM + const AddressFrame ptr_b, xword_b, yword_b, zword_b; // broadcast such as {1to2}, {1to4}, {1to8}, {1to16}, {b} + const Fpu st0, st1, st2, st3, st4, st5, st6, st7; + const Opmask k0, k1, k2, k3, k4, k5, k6, k7; + const BoundsReg bnd0, bnd1, bnd2, bnd3; + const EvexModifierRounding T_sae, T_rn_sae, T_rd_sae, T_ru_sae, T_rz_sae; // {sae}, {rn-sae}, {rd-sae}, {ru-sae}, {rz-sae} + const EvexModifierZero T_z; // {z} +#ifdef XBYAK64 + const Reg64 rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15; + const Reg32 r8d, r9d, r10d, r11d, r12d, r13d, r14d, r15d; + const Reg16 r8w, r9w, r10w, r11w, r12w, r13w, r14w, r15w; + const Reg8 r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b; + const Reg8 spl, bpl, sil, dil; + const Xmm xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15; + const Xmm xmm16, xmm17, xmm18, xmm19, xmm20, xmm21, xmm22, xmm23; + const Xmm xmm24, xmm25, xmm26, xmm27, xmm28, xmm29, xmm30, xmm31; + const Ymm ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15; + const Ymm ymm16, ymm17, ymm18, ymm19, ymm20, ymm21, ymm22, ymm23; + const Ymm ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; + const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15; + const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23; + const Zmm zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience + const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23; + const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31; + const Ymm &ym8, &ym9, &ym10, &ym11, &ym12, &ym13, &ym14, &ym15; + const Ymm &ym16, &ym17, &ym18, &ym19, &ym20, &ym21, &ym22, &ym23; + const Ymm &ym24, &ym25, &ym26, &ym27, &ym28, &ym29, &ym30, &ym31; + const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15; + const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23; + const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31; + const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT + const Segment es, cs, ss, ds, fs, gs; +#endif + void L(const std::string& label) { labelMgr_.defineSlabel(label); } + void L(Label& label) { labelMgr_.defineClabel(label); } + Label L() { Label label; L(label); return label; } + void inLocalLabel() { labelMgr_.enterLocal(); } + void outLocalLabel() { labelMgr_.leaveLocal(); } + /* + assign src to dst + require + dst : does not used by L() + src : used by L() + */ + void assignL(Label& dst, const Label& src) { labelMgr_.assign(dst, src); } + /* + put address of label to buffer + @note the put size is 4(32-bit), 8(64-bit) + */ + void putL(std::string label) { putL_inner(label); } + void putL(const Label& label) { putL_inner(label); } + + void jmp(const Operand& op) { opR_ModM(op, BIT, 4, 0xFF, NONE, NONE, true); } + void jmp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const char *label, LabelType type = T_AUTO) { jmp(std::string(label), type); } + void jmp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const void *addr, LabelType type = T_AUTO) { opJmpAbs(addr, type, 0xEB, 0xE9); } + + void call(const Operand& op) { opR_ModM(op, 16 | i32e, 2, 0xFF, NONE, NONE, true); } + // call(string label), not const std::string& + void call(std::string label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + void call(const char *label) { call(std::string(label)); } + void call(const Label& label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + // call(function pointer) +#ifdef XBYAK_VARIADIC_TEMPLATE + template + void call(Ret(*func)(Params...)) { call(reinterpret_cast(func)); } +#endif + void call(const void *addr) { opJmpAbs(addr, T_NEAR, 0, 0xE8); } + + void test(const Operand& op, const Reg& reg) + { + opModRM(reg, op, op.isREG() && (op.getKind() == reg.getKind()), op.isMEM(), 0x84); + } + void test(const Operand& op, uint32 imm) + { + verifyMemHasSize(op); + int immSize = (std::min)(op.getBit() / 8, 4U); + if (op.isREG() && op.getIdx() == 0) { // al, ax, eax + rex(op); + db(0xA8 | (op.isBit(8) ? 0 : 1)); + } else { + opR_ModM(op, 0, 0, 0xF6, NONE, NONE, false, immSize); + } + db(imm, immSize); + } + void imul(const Reg& reg, const Operand& op) + { + opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x0F, 0xAF); + } + void imul(const Reg& reg, const Operand& op, int imm) + { + int s = inner::IsInDisp8(imm) ? 1 : 0; + int immSize = s ? 1 : reg.isREG(16) ? 2 : 4; + opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x69 | (s << 1), NONE, NONE, immSize); + db(imm, immSize); + } + void push(const Operand& op) { opPushPop(op, 0xFF, 6, 0x50); } + void pop(const Operand& op) { opPushPop(op, 0x8F, 0, 0x58); } + void push(const AddressFrame& af, uint32 imm) + { + if (af.bit_ == 8 && inner::IsInDisp8(imm)) { + db(0x6A); db(imm); + } else if (af.bit_ == 16 && isInDisp16(imm)) { + db(0x66); db(0x68); dw(imm); + } else { + db(0x68); dd(imm); + } + } + /* use "push(word, 4)" if you want "push word 4" */ + void push(uint32 imm) + { + if (inner::IsInDisp8(imm)) { + push(byte, imm); + } else { + push(dword, imm); + } + } + void mov(const Operand& reg1, const Operand& reg2) + { + const Reg *reg = 0; + const Address *addr = 0; + uint8 code = 0; + if (reg1.isREG() && reg1.getIdx() == 0 && reg2.isMEM()) { // mov eax|ax|al, [disp] + reg = ®1.getReg(); + addr= ®2.getAddress(); + code = 0xA0; + } else + if (reg1.isMEM() && reg2.isREG() && reg2.getIdx() == 0) { // mov [disp], eax|ax|al + reg = ®2.getReg(); + addr= ®1.getAddress(); + code = 0xA2; + } +#ifdef XBYAK64 + if (addr && addr->is64bitDisp()) { + if (code) { + rex(*reg); + db(reg1.isREG(8) ? 0xA0 : reg1.isREG() ? 0xA1 : reg2.isREG(8) ? 0xA2 : 0xA3); + db(addr->getDisp(), 8); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } else +#else + if (code && addr->isOnlyDisp()) { + rex(*reg, *addr); + db(code | (reg->isBit(8) ? 0 : 1)); + dd(static_cast(addr->getDisp())); + } else +#endif + { + opRM_RM(reg1, reg2, 0x88); + } + } + void mov(const Operand& op, size_t imm) + { + if (op.isREG()) { + const int size = mov_imm(op.getReg(), imm); + db(imm, size); + } else if (op.isMEM()) { + verifyMemHasSize(op); + int immSize = op.getBit() / 8; + if (immSize <= 4) { + sint64 s = sint64(imm) >> (immSize * 8); + if (s != 0 && s != -1) throw Error(ERR_IMM_IS_TOO_BIG); + } else { + if (!inner::IsInInt32(imm)) throw Error(ERR_IMM_IS_TOO_BIG); + immSize = 4; + } + opModM(op.getAddress(), Reg(0, Operand::REG, op.getBit()), 0xC6, NONE, NONE, immSize); + db(static_cast(imm), immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void mov(const NativeReg& reg, const char *label) // can't use std::string + { + if (label == 0) { + mov(static_cast(reg), 0); // call imm + return; + } + mov_imm(reg, dummyAddr); + putL(label); + } + void mov(const NativeReg& reg, const Label& label) + { + mov_imm(reg, dummyAddr); + putL(label); + } + void xchg(const Operand& op1, const Operand& op2) + { + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM() || (p2->isREG(16 | i32e) && p2->getIdx() == 0)) { + p1 = &op2; p2 = &op1; + } + if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); + if (p2->isREG() && (p1->isREG(16 | i32e) && p1->getIdx() == 0) +#ifdef XBYAK64 + && (p2->getIdx() != 0 || !p1->isREG(32)) +#endif + ) { + rex(*p2, *p1); db(0x90 | (p2->getIdx() & 7)); + return; + } + opModRM(*p1, *p2, (p1->isREG() && p2->isREG() && (p1->getBit() == p2->getBit())), p2->isMEM(), 0x86 | (p1->isBit(8) ? 0 : 1)); + } + +#ifndef XBYAK_DISABLE_SEGMENT + void push(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x06); break; + case Segment::cs: db(0x0E); break; + case Segment::ss: db(0x16); break; + case Segment::ds: db(0x1E); break; + case Segment::fs: db(0x0F); db(0xA0); break; + case Segment::gs: db(0x0F); db(0xA8); break; + default: + assert(0); + } + } + void pop(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x07); break; + case Segment::cs: throw Error(ERR_BAD_COMBINATION); + case Segment::ss: db(0x17); break; + case Segment::ds: db(0x1F); break; + case Segment::fs: db(0x0F); db(0xA1); break; + case Segment::gs: db(0x0F); db(0xA9); break; + default: + assert(0); + } + } + void putSeg(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x2E); break; + case Segment::cs: db(0x36); break; + case Segment::ss: db(0x3E); break; + case Segment::ds: db(0x26); break; + case Segment::fs: db(0x64); break; + case Segment::gs: db(0x65); break; + default: + assert(0); + } + } + void mov(const Operand& op, const Segment& seg) + { + opModRM(Reg8(seg.getIdx()), op, op.isREG(16|i32e), op.isMEM(), 0x8C); + } + void mov(const Segment& seg, const Operand& op) + { + opModRM(Reg8(seg.getIdx()), op.isREG(16|i32e) ? static_cast(op.getReg().cvt32()) : op, op.isREG(16|i32e), op.isMEM(), 0x8E); + } +#endif + + enum { NONE = 256 }; + // constructor + CodeGenerator(size_t maxSize = DEFAULT_MAX_CODE_SIZE, void *userPtr = 0, Allocator *allocator = 0) + : CodeArray(maxSize, userPtr, allocator) + , mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7) + , xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7) + , ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7) + , zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7) + // for my convenience + , xm0(xmm0), xm1(xmm1), xm2(xmm2), xm3(xmm3), xm4(xmm4), xm5(xmm5), xm6(xmm6), xm7(xmm7) + , ym0(ymm0), ym1(ymm1), ym2(ymm2), ym3(ymm3), ym4(ymm4), ym5(ymm5), ym6(ymm6), ym7(ymm7) + , zm0(zmm0), zm1(zmm1), zm2(zmm2), zm3(zmm3), zm4(zmm4), zm5(zmm5), zm6(zmm6), zm7(zmm7) + + , eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI) + , ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI) + , al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH) + , ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512) + , ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true) + , st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7) + , k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7) + , bnd0(0), bnd1(1), bnd2(2), bnd3(3) + , T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE) + , T_z() +#ifdef XBYAK64 + , rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15) + , r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15) + , r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15) + , r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15) + , spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true) + , xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15) + , xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23) + , xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31) + , ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15) + , ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23) + , ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31) + , zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15) + , zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23) + , zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31) + // for my convenience + , xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15) + , xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23) + , xm24(xmm24), xm25(xmm25), xm26(xmm26), xm27(xmm27), xm28(xmm28), xm29(xmm29), xm30(xmm30), xm31(xmm31) + , ym8(ymm8), ym9(ymm9), ym10(ymm10), ym11(ymm11), ym12(ymm12), ym13(ymm13), ym14(ymm14), ym15(ymm15) + , ym16(ymm16), ym17(ymm17), ym18(ymm18), ym19(ymm19), ym20(ymm20), ym21(ymm21), ym22(ymm22), ym23(ymm23) + , ym24(ymm24), ym25(ymm25), ym26(ymm26), ym27(ymm27), ym28(ymm28), ym29(ymm29), ym30(ymm30), ym31(ymm31) + , zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15) + , zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23) + , zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31) + , rip() +#endif +#ifndef XBYAK_DISABLE_SEGMENT + , es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs) +#endif + { + labelMgr_.set(this); + } + void reset() + { + resetSize(); + labelMgr_.reset(); + labelMgr_.set(this); + } + bool hasUndefinedLabel() const { return labelMgr_.hasUndefSlabel() || labelMgr_.hasUndefClabel(); } + /* + MUST call ready() to complete generating code if you use AutoGrow mode. + It is not necessary for the other mode if hasUndefinedLabel() is true. + */ + void ready(ProtectMode mode = PROTECT_RWE) + { + if (hasUndefinedLabel()) throw Error(ERR_LABEL_IS_NOT_FOUND); + if (isAutoGrow()) { + calcJmpAddress(); + if (useProtect()) setProtectMode(mode); + } + } + // set read/exec + void readyRE() { return ready(PROTECT_RE); } +#ifdef XBYAK_TEST + void dump(bool doClear = true) + { + CodeArray::dump(); + if (doClear) size_ = 0; + } +#endif + +#ifdef XBYAK_UNDEF_JNL + #undef jnl +#endif + + /* + use single byte nop if useMultiByteNop = false + */ + void nop(size_t size = 1, bool useMultiByteNop = true) + { + if (!useMultiByteNop) { + for (size_t i = 0; i < size; i++) { + db(0x90); + } + return; + } + /* + Intel Architectures Software Developer's Manual Volume 2 + recommended multi-byte sequence of NOP instruction + AMD and Intel seem to agree on the same sequences for up to 9 bytes: + https://support.amd.com/TechDocs/55723_SOG_Fam_17h_Processors_3.00.pdf + */ + static const uint8 nopTbl[9][9] = { + {0x90}, + {0x66, 0x90}, + {0x0F, 0x1F, 0x00}, + {0x0F, 0x1F, 0x40, 0x00}, + {0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00}, + {0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + }; + const size_t n = sizeof(nopTbl) / sizeof(nopTbl[0]); + while (size > 0) { + size_t len = (std::min)(n, size); + const uint8 *seq = nopTbl[len - 1]; + db(seq, len); + size -= len; + } + } + +#ifndef XBYAK_DONT_READ_LIST +#include "xbyak_mnemonic.h" + /* + use single byte nop if useMultiByteNop = false + */ + void align(size_t x = 16, bool useMultiByteNop = true) + { + if (x == 1) return; + if (x < 1 || (x & (x - 1))) throw Error(ERR_BAD_ALIGN); + if (isAutoGrow() && x > inner::ALIGN_PAGE_SIZE) fprintf(stderr, "warning:autoGrow mode does not support %d align\n", (int)x); + size_t remain = size_t(getCurr()) % x; + if (remain) { + nop(x - remain, useMultiByteNop); + } + } +#endif +}; + +namespace util { +static const Mmx mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7); +static const Xmm xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7); +static const Ymm ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7); +static const Zmm zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7); +static const Reg32 eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI); +static const Reg16 ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI); +static const Reg8 al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH); +static const AddressFrame ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512); +static const AddressFrame ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true); +static const Fpu st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7); +static const Opmask k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7); +static const BoundsReg bnd0(0), bnd1(1), bnd2(2), bnd3(3); +static const EvexModifierRounding T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE); +static const EvexModifierZero T_z; +#ifdef XBYAK64 +static const Reg64 rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15); +static const Reg32 r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15); +static const Reg16 r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15); +static const Reg8 r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15), spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true); +static const Xmm xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15); +static const Xmm xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23); +static const Xmm xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31); +static const Ymm ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15); +static const Ymm ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23); +static const Ymm ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31); +static const Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15); +static const Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23); +static const Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31); +static const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT +static const Segment es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs); +#endif +} // util + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +} // end of namespace + +#endif // XBYAK_XBYAK_H_ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h new file mode 100644 index 0000000000..a22e5224c3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h @@ -0,0 +1,303 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +enum { + B00000000= 0, + B00000001= 1, + B00000010= 2, + B00000011= 3, + B00000100= 4, + B00000101= 5, + B00000110= 6, + B00000111= 7, + B00001000= 8, + B00001001= 9, + B00001010= 10, + B00001011= 11, + B00001100= 12, + B00001101= 13, + B00001110= 14, + B00001111= 15, + B00010000= 16, + B00010001= 17, + B00010010= 18, + B00010011= 19, + B00010100= 20, + B00010101= 21, + B00010110= 22, + B00010111= 23, + B00011000= 24, + B00011001= 25, + B00011010= 26, + B00011011= 27, + B00011100= 28, + B00011101= 29, + B00011110= 30, + B00011111= 31, + B00100000= 32, + B00100001= 33, + B00100010= 34, + B00100011= 35, + B00100100= 36, + B00100101= 37, + B00100110= 38, + B00100111= 39, + B00101000= 40, + B00101001= 41, + B00101010= 42, + B00101011= 43, + B00101100= 44, + B00101101= 45, + B00101110= 46, + B00101111= 47, + B00110000= 48, + B00110001= 49, + B00110010= 50, + B00110011= 51, + B00110100= 52, + B00110101= 53, + B00110110= 54, + B00110111= 55, + B00111000= 56, + B00111001= 57, + B00111010= 58, + B00111011= 59, + B00111100= 60, + B00111101= 61, + B00111110= 62, + B00111111= 63, + B01000000= 64, + B01000001= 65, + B01000010= 66, + B01000011= 67, + B01000100= 68, + B01000101= 69, + B01000110= 70, + B01000111= 71, + B01001000= 72, + B01001001= 73, + B01001010= 74, + B01001011= 75, + B01001100= 76, + B01001101= 77, + B01001110= 78, + B01001111= 79, + B01010000= 80, + B01010001= 81, + B01010010= 82, + B01010011= 83, + B01010100= 84, + B01010101= 85, + B01010110= 86, + B01010111= 87, + B01011000= 88, + B01011001= 89, + B01011010= 90, + B01011011= 91, + B01011100= 92, + B01011101= 93, + B01011110= 94, + B01011111= 95, + B01100000= 96, + B01100001= 97, + B01100010= 98, + B01100011= 99, + B01100100= 100, + B01100101= 101, + B01100110= 102, + B01100111= 103, + B01101000= 104, + B01101001= 105, + B01101010= 106, + B01101011= 107, + B01101100= 108, + B01101101= 109, + B01101110= 110, + B01101111= 111, + B01110000= 112, + B01110001= 113, + B01110010= 114, + B01110011= 115, + B01110100= 116, + B01110101= 117, + B01110110= 118, + B01110111= 119, + B01111000= 120, + B01111001= 121, + B01111010= 122, + B01111011= 123, + B01111100= 124, + B01111101= 125, + B01111110= 126, + B01111111= 127, + B10000000= 128, + B10000001= 129, + B10000010= 130, + B10000011= 131, + B10000100= 132, + B10000101= 133, + B10000110= 134, + B10000111= 135, + B10001000= 136, + B10001001= 137, + B10001010= 138, + B10001011= 139, + B10001100= 140, + B10001101= 141, + B10001110= 142, + B10001111= 143, + B10010000= 144, + B10010001= 145, + B10010010= 146, + B10010011= 147, + B10010100= 148, + B10010101= 149, + B10010110= 150, + B10010111= 151, + B10011000= 152, + B10011001= 153, + B10011010= 154, + B10011011= 155, + B10011100= 156, + B10011101= 157, + B10011110= 158, + B10011111= 159, + B10100000= 160, + B10100001= 161, + B10100010= 162, + B10100011= 163, + B10100100= 164, + B10100101= 165, + B10100110= 166, + B10100111= 167, + B10101000= 168, + B10101001= 169, + B10101010= 170, + B10101011= 171, + B10101100= 172, + B10101101= 173, + B10101110= 174, + B10101111= 175, + B10110000= 176, + B10110001= 177, + B10110010= 178, + B10110011= 179, + B10110100= 180, + B10110101= 181, + B10110110= 182, + B10110111= 183, + B10111000= 184, + B10111001= 185, + B10111010= 186, + B10111011= 187, + B10111100= 188, + B10111101= 189, + B10111110= 190, + B10111111= 191, + B11000000= 192, + B11000001= 193, + B11000010= 194, + B11000011= 195, + B11000100= 196, + B11000101= 197, + B11000110= 198, + B11000111= 199, + B11001000= 200, + B11001001= 201, + B11001010= 202, + B11001011= 203, + B11001100= 204, + B11001101= 205, + B11001110= 206, + B11001111= 207, + B11010000= 208, + B11010001= 209, + B11010010= 210, + B11010011= 211, + B11010100= 212, + B11010101= 213, + B11010110= 214, + B11010111= 215, + B11011000= 216, + B11011001= 217, + B11011010= 218, + B11011011= 219, + B11011100= 220, + B11011101= 221, + B11011110= 222, + B11011111= 223, + B11100000= 224, + B11100001= 225, + B11100010= 226, + B11100011= 227, + B11100100= 228, + B11100101= 229, + B11100110= 230, + B11100111= 231, + B11101000= 232, + B11101001= 233, + B11101010= 234, + B11101011= 235, + B11101100= 236, + B11101101= 237, + B11101110= 238, + B11101111= 239, + B11110000= 240, + B11110001= 241, + B11110010= 242, + B11110011= 243, + B11110100= 244, + B11110101= 245, + B11110110= 246, + B11110111= 247, + B11111000= 248, + B11111001= 249, + B11111010= 250, + B11111011= 251, + B11111100= 252, + B11111101= 253, + B11111110= 254, + B11111111= 255 +}; diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h new file mode 100644 index 0000000000..28d2d222f9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h @@ -0,0 +1,2017 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +const char *getVersionString() const { return "5.76"; } +void adc(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x10, 2); } +void adc(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x10); } +void adcx(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0x66, isREG32_REG32orMEM, NONE, 0x38); } +void add(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x00, 0); } +void add(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x00); } +void addpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x66, isXMM_XMMorMEM); } +void addps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x100, isXMM_XMMorMEM); } +void addsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF2, isXMM_XMMorMEM); } +void addss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF3, isXMM_XMMorMEM); } +void addsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0x66, isXMM_XMMorMEM); } +void addsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0xF2, isXMM_XMMorMEM); } +void adox(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0xF3, isREG32_REG32orMEM, NONE, 0x38); } +void aesdec(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDE, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesdeclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesenc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDC, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesenclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDD, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesimc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDB, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aeskeygenassist(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void and_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x20, 4); } +void and_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x20); } +void andn(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_0F38, 0xf2, true); } +void andnpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x66, isXMM_XMMorMEM); } +void andnps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x100, isXMM_XMMorMEM); } +void andpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x66, isXMM_XMMorMEM); } +void andps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x100, isXMM_XMMorMEM); } +void bextr(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf7, false); } +void blendpd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0D, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void blendps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0C, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void blendvpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void blendvps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void blsi(const Reg32e& r, const Operand& op) { opGpr(Reg32e(3, r.getBit()), op, r, T_0F38, 0xf3, false); } +void blsmsk(const Reg32e& r, const Operand& op) { opGpr(Reg32e(2, r.getBit()), op, r, T_0F38, 0xf3, false); } +void blsr(const Reg32e& r, const Operand& op) { opGpr(Reg32e(1, r.getBit()), op, r, T_0F38, 0xf3, false); } +void bnd() { db(0xF2); } +void bndcl(const BoundsReg& bnd, const Operand& op) { db(0xF3); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } +void bndcn(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1B, NONE, !op.isMEM()); } +void bndcu(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } +void bndldx(const BoundsReg& bnd, const Address& addr) { opMIB(addr, bnd, 0x0F, 0x1A); } +void bndmk(const BoundsReg& bnd, const Address& addr) { db(0xF3); opModM(addr, bnd, 0x0F, 0x1B); } +void bndmov(const Address& addr, const BoundsReg& bnd) { db(0x66); opModM(addr, bnd, 0x0F, 0x1B); } +void bndmov(const BoundsReg& bnd, const Operand& op) { db(0x66); opModRM(bnd, op, op.isBNDREG(), op.isMEM(), 0x0F, 0x1A); } +void bndstx(const Address& addr, const BoundsReg& bnd) { opMIB(addr, bnd, 0x0F, 0x1B); } +void bsf(const Reg®, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBC); } +void bsr(const Reg®, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBD); } +void bswap(const Reg32e& reg) { opModR(Reg32(1), reg, 0x0F); } +void bt(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xA3); } +void bt(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 4, 0x0f, 0xba, NONE, false, 1); db(imm); } +void btc(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xBB); } +void btc(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 7, 0x0f, 0xba, NONE, false, 1); db(imm); } +void btr(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xB3); } +void btr(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 6, 0x0f, 0xba, NONE, false, 1); db(imm); } +void bts(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xAB); } +void bts(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 5, 0x0f, 0xba, NONE, false, 1); db(imm); } +void bzhi(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf5, false); } +void cbw() { db(0x66); db(0x98); } +void cdq() { db(0x99); } +void clc() { db(0xF8); } +void cld() { db(0xFC); } +void clflush(const Address& addr) { opModM(addr, Reg32(7), 0x0F, 0xAE); } +void cli() { db(0xFA); } +void cmc() { db(0xF5); } +void cmova(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 +void cmovae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmovbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 +void cmovc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmove(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 +void cmovg(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 +void cmovge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 +void cmovl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 +void cmovle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 +void cmovna(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 +void cmovnae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmovnb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovnbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 +void cmovnc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovne(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 +void cmovng(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 +void cmovnge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 +void cmovnl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 +void cmovnle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 +void cmovno(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 1); }//-V524 +void cmovnp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 +void cmovns(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 9); }//-V524 +void cmovnz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 +void cmovo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 0); }//-V524 +void cmovp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 +void cmovpe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 +void cmovpo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 +void cmovs(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 8); }//-V524 +void cmovz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 +void cmp(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x38, 7); } +void cmp(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x38); } +void cmpeqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 0); } +void cmpeqps(const Xmm& x, const Operand& op) { cmpps(x, op, 0); } +void cmpeqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 0); } +void cmpeqss(const Xmm& x, const Operand& op) { cmpss(x, op, 0); } +void cmplepd(const Xmm& x, const Operand& op) { cmppd(x, op, 2); } +void cmpleps(const Xmm& x, const Operand& op) { cmpps(x, op, 2); } +void cmplesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 2); } +void cmpless(const Xmm& x, const Operand& op) { cmpss(x, op, 2); } +void cmpltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 1); } +void cmpltps(const Xmm& x, const Operand& op) { cmpps(x, op, 1); } +void cmpltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 1); } +void cmpltss(const Xmm& x, const Operand& op) { cmpss(x, op, 1); } +void cmpneqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 4); } +void cmpneqps(const Xmm& x, const Operand& op) { cmpps(x, op, 4); } +void cmpneqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 4); } +void cmpneqss(const Xmm& x, const Operand& op) { cmpss(x, op, 4); } +void cmpnlepd(const Xmm& x, const Operand& op) { cmppd(x, op, 6); } +void cmpnleps(const Xmm& x, const Operand& op) { cmpps(x, op, 6); } +void cmpnlesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 6); } +void cmpnless(const Xmm& x, const Operand& op) { cmpss(x, op, 6); } +void cmpnltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 5); } +void cmpnltps(const Xmm& x, const Operand& op) { cmpps(x, op, 5); } +void cmpnltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 5); } +void cmpnltss(const Xmm& x, const Operand& op) { cmpss(x, op, 5); } +void cmpordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 7); } +void cmpordps(const Xmm& x, const Operand& op) { cmpps(x, op, 7); } +void cmpordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 7); } +void cmpordss(const Xmm& x, const Operand& op) { cmpss(x, op, 7); } +void cmppd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x66, isXMM_XMMorMEM, imm8); } +void cmpps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x100, isXMM_XMMorMEM, imm8); } +void cmpsb() { db(0xA6); } +void cmpsd() { db(0xA7); } +void cmpsd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF2, isXMM_XMMorMEM, imm8); } +void cmpss(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF3, isXMM_XMMorMEM, imm8); } +void cmpsw() { db(0x66); db(0xA7); } +void cmpunordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 3); } +void cmpunordps(const Xmm& x, const Operand& op) { cmpps(x, op, 3); } +void cmpunordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 3); } +void cmpunordss(const Xmm& x, const Operand& op) { cmpss(x, op, 3); } +void cmpxchg(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xB0 | (reg.isBit(8) ? 0 : 1)); } +void cmpxchg8b(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0xC7); } +void comisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x66, isXMM_XMMorMEM); } +void comiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x100, isXMM_XMMorMEM); } +void cpuid() { db(0x0F); db(0xA2); } +void crc32(const Reg32e& reg, const Operand& op) { if (reg.isBit(32) && op.isBit(16)) db(0x66); db(0xF2); opModRM(reg, op, op.isREG(), op.isMEM(), 0x0F, 0x38, 0xF0 | (op.isBit(8) ? 0 : 1)); } +void cvtdq2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF3, isXMM_XMMorMEM); } +void cvtdq2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x100, isXMM_XMMorMEM); } +void cvtpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF2, isXMM_XMMorMEM); } +void cvtpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x66, isMMX_XMMorMEM); } +void cvtpd2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x66, isXMM_XMMorMEM); } +void cvtpi2pd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x66, isXMM_MMXorMEM); } +void cvtpi2ps(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x100, isXMM_MMXorMEM); } +void cvtps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x66, isXMM_XMMorMEM); } +void cvtps2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x100, isXMM_XMMorMEM); } +void cvtps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x100, isMMX_XMMorMEM); } +void cvtsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF2, isREG32_XMMorMEM); } +void cvtsd2ss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF2, isXMM_XMMorMEM); } +void cvtsi2sd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF2, isXMM_REG32orMEM); } +void cvtsi2ss(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF3, isXMM_REG32orMEM); } +void cvtss2sd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF3, isXMM_XMMorMEM); } +void cvtss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF3, isREG32_XMMorMEM); } +void cvttpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0x66, isXMM_XMMorMEM); } +void cvttpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x66, isMMX_XMMorMEM); } +void cvttps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0xF3, isXMM_XMMorMEM); } +void cvttps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x100, isMMX_XMMorMEM); } +void cvttsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF2, isREG32_XMMorMEM); } +void cvttss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF3, isREG32_XMMorMEM); } +void cwd() { db(0x66); db(0x99); } +void cwde() { db(0x98); } +void dec(const Operand& op) { opIncDec(op, 0x48, 1); } +void div(const Operand& op) { opR_ModM(op, 0, 6, 0xF6); } +void divpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x66, isXMM_XMMorMEM); } +void divps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x100, isXMM_XMMorMEM); } +void divsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF2, isXMM_XMMorMEM); } +void divss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF3, isXMM_XMMorMEM); } +void dppd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void dpps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void emms() { db(0x0F); db(0x77); } +void extractps(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x17, imm); } +void f2xm1() { db(0xD9); db(0xF0); } +void fabs() { db(0xD9); db(0xE1); } +void fadd(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 0, 0); } +void fadd(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C0, 0xDCC0); } +void fadd(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C0, 0xDCC0); } +void faddp() { db(0xDE); db(0xC1); } +void faddp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC0); } +void faddp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC0); } +void fchs() { db(0xD9); db(0xE0); } +void fcmovb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC0, 0x00C0); } +void fcmovb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC0, 0x00C0); } +void fcmovbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD0, 0x00D0); } +void fcmovbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD0, 0x00D0); } +void fcmove(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC8, 0x00C8); } +void fcmove(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC8, 0x00C8); } +void fcmovnb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC0, 0x00C0); } +void fcmovnb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC0, 0x00C0); } +void fcmovnbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD0, 0x00D0); } +void fcmovnbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD0, 0x00D0); } +void fcmovne(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC8, 0x00C8); } +void fcmovne(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC8, 0x00C8); } +void fcmovnu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD8, 0x00D8); } +void fcmovnu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD8, 0x00D8); } +void fcmovu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD8, 0x00D8); } +void fcmovu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD8, 0x00D8); } +void fcom() { db(0xD8); db(0xD1); } +void fcom(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 2, 0); } +void fcom(const Fpu& reg) { opFpu(reg, 0xD8, 0xD0); } +void fcomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBF0, 0x00F0); } +void fcomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBF0, 0x00F0); } +void fcomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFF0, 0x00F0); } +void fcomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFF0, 0x00F0); } +void fcomp() { db(0xD8); db(0xD9); } +void fcomp(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 3, 0); } +void fcomp(const Fpu& reg) { opFpu(reg, 0xD8, 0xD8); } +void fcompp() { db(0xDE); db(0xD9); } +void fcos() { db(0xD9); db(0xFF); } +void fdecstp() { db(0xD9); db(0xF6); } +void fdiv(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 6, 0); } +void fdiv(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F0, 0xDCF8); } +void fdiv(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F0, 0xDCF8); } +void fdivp() { db(0xDE); db(0xF9); } +void fdivp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF8); } +void fdivp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF8); } +void fdivr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 7, 0); } +void fdivr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F8, 0xDCF0); } +void fdivr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F8, 0xDCF0); } +void fdivrp() { db(0xDE); db(0xF1); } +void fdivrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF0); } +void fdivrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF0); } +void ffree(const Fpu& reg) { opFpu(reg, 0xDD, 0xC0); } +void fiadd(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 0, 0); } +void ficom(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 2, 0); } +void ficomp(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 3, 0); } +void fidiv(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 6, 0); } +void fidivr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 7, 0); } +void fild(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 0, 5); } +void fimul(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 1, 0); } +void fincstp() { db(0xD9); db(0xF7); } +void finit() { db(0x9B); db(0xDB); db(0xE3); } +void fist(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0x00, 2, 0); } +void fistp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 3, 7); } +void fisttp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDD, 1, 0); } +void fisub(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 4, 0); } +void fisubr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 5, 0); } +void fld(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 0, 0); } +void fld(const Fpu& reg) { opFpu(reg, 0xD9, 0xC0); } +void fld1() { db(0xD9); db(0xE8); } +void fldcw(const Address& addr) { opModM(addr, Reg32(5), 0xD9, 0x100); } +void fldl2e() { db(0xD9); db(0xEA); } +void fldl2t() { db(0xD9); db(0xE9); } +void fldlg2() { db(0xD9); db(0xEC); } +void fldln2() { db(0xD9); db(0xED); } +void fldpi() { db(0xD9); db(0xEB); } +void fldz() { db(0xD9); db(0xEE); } +void fmul(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 1, 0); } +void fmul(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C8, 0xDCC8); } +void fmul(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C8, 0xDCC8); } +void fmulp() { db(0xDE); db(0xC9); } +void fmulp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC8); } +void fmulp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC8); } +void fninit() { db(0xDB); db(0xE3); } +void fnop() { db(0xD9); db(0xD0); } +void fpatan() { db(0xD9); db(0xF3); } +void fprem() { db(0xD9); db(0xF8); } +void fprem1() { db(0xD9); db(0xF5); } +void fptan() { db(0xD9); db(0xF2); } +void frndint() { db(0xD9); db(0xFC); } +void fscale() { db(0xD9); db(0xFD); } +void fsin() { db(0xD9); db(0xFE); } +void fsincos() { db(0xD9); db(0xFB); } +void fsqrt() { db(0xD9); db(0xFA); } +void fst(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 2, 0); } +void fst(const Fpu& reg) { opFpu(reg, 0xDD, 0xD0); } +void fstcw(const Address& addr) { db(0x9B); opModM(addr, Reg32(7), 0xD9, NONE); } +void fstp(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 3, 0); } +void fstp(const Fpu& reg) { opFpu(reg, 0xDD, 0xD8); } +void fsub(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 4, 0); } +void fsub(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E0, 0xDCE8); } +void fsub(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E0, 0xDCE8); } +void fsubp() { db(0xDE); db(0xE9); } +void fsubp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE8); } +void fsubp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE8); } +void fsubr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 5, 0); } +void fsubr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E8, 0xDCE0); } +void fsubr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E8, 0xDCE0); } +void fsubrp() { db(0xDE); db(0xE1); } +void fsubrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE0); } +void fsubrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE0); } +void ftst() { db(0xD9); db(0xE4); } +void fucom() { db(0xDD); db(0xE1); } +void fucom(const Fpu& reg) { opFpu(reg, 0xDD, 0xE0); } +void fucomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBE8, 0x00E8); } +void fucomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBE8, 0x00E8); } +void fucomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFE8, 0x00E8); } +void fucomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFE8, 0x00E8); } +void fucomp() { db(0xDD); db(0xE9); } +void fucomp(const Fpu& reg) { opFpu(reg, 0xDD, 0xE8); } +void fucompp() { db(0xDA); db(0xE9); } +void fwait() { db(0x9B); } +void fxam() { db(0xD9); db(0xE5); } +void fxch() { db(0xD9); db(0xC9); } +void fxch(const Fpu& reg) { opFpu(reg, 0xD9, 0xC8); } +void fxtract() { db(0xD9); db(0xF4); } +void fyl2x() { db(0xD9); db(0xF1); } +void fyl2xp1() { db(0xD9); db(0xF9); } +void gf2p8affineinvqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void gf2p8affineqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCE, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void gf2p8mulb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void haddpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0x66, isXMM_XMMorMEM); } +void haddps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0xF2, isXMM_XMMorMEM); } +void hsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0x66, isXMM_XMMorMEM); } +void hsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0xF2, isXMM_XMMorMEM); } +void idiv(const Operand& op) { opR_ModM(op, 0, 7, 0xF6); } +void imul(const Operand& op) { opR_ModM(op, 0, 5, 0xF6); } +void inc(const Operand& op) { opIncDec(op, 0x40, 0); } +void insertps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void ja(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void ja(const char *label, LabelType type = T_AUTO) { ja(std::string(label), type); }//-V524 +void ja(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void ja(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jae(const char *label, LabelType type = T_AUTO) { jae(std::string(label), type); }//-V524 +void jae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jb(const char *label, LabelType type = T_AUTO) { jb(std::string(label), type); }//-V524 +void jb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jbe(const char *label, LabelType type = T_AUTO) { jbe(std::string(label), type); }//-V524 +void jbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jc(const char *label, LabelType type = T_AUTO) { jc(std::string(label), type); }//-V524 +void jc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void je(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void je(const char *label, LabelType type = T_AUTO) { je(std::string(label), type); }//-V524 +void je(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void je(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jg(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(const char *label, LabelType type = T_AUTO) { jg(std::string(label), type); }//-V524 +void jg(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(const char *label, LabelType type = T_AUTO) { jge(std::string(label), type); }//-V524 +void jge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(const char *label, LabelType type = T_AUTO) { jl(std::string(label), type); }//-V524 +void jl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(const char *label, LabelType type = T_AUTO) { jle(std::string(label), type); }//-V524 +void jle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jna(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jna(const char *label, LabelType type = T_AUTO) { jna(std::string(label), type); }//-V524 +void jna(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jna(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jnae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnae(const char *label, LabelType type = T_AUTO) { jnae(std::string(label), type); }//-V524 +void jnae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jnae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnb(const char *label, LabelType type = T_AUTO) { jnb(std::string(label), type); }//-V524 +void jnb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(const char *label, LabelType type = T_AUTO) { jnbe(std::string(label), type); }//-V524 +void jnbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnc(const char *label, LabelType type = T_AUTO) { jnc(std::string(label), type); }//-V524 +void jnc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jne(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jne(const char *label, LabelType type = T_AUTO) { jne(std::string(label), type); }//-V524 +void jne(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jne(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jng(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(const char *label, LabelType type = T_AUTO) { jng(std::string(label), type); }//-V524 +void jng(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jnge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(const char *label, LabelType type = T_AUTO) { jnge(std::string(label), type); }//-V524 +void jnge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(const char *label, LabelType type = T_AUTO) { jnl(std::string(label), type); }//-V524 +void jnl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(const char *label, LabelType type = T_AUTO) { jnle(std::string(label), type); }//-V524 +void jnle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jno(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jno(const char *label, LabelType type = T_AUTO) { jno(std::string(label), type); }//-V524 +void jno(const void *addr) { opJmpAbs(addr, T_NEAR, 0x71, 0x81, 0x0F); }//-V524 +void jno(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jnp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(const char *label, LabelType type = T_AUTO) { jnp(std::string(label), type); }//-V524 +void jnp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jns(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jns(const char *label, LabelType type = T_AUTO) { jns(std::string(label), type); }//-V524 +void jns(const void *addr) { opJmpAbs(addr, T_NEAR, 0x79, 0x89, 0x0F); }//-V524 +void jns(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jnz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jnz(const char *label, LabelType type = T_AUTO) { jnz(std::string(label), type); }//-V524 +void jnz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jnz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jo(const char *label, LabelType type = T_AUTO) { jo(std::string(label), type); }//-V524 +void jo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x70, 0x80, 0x0F); }//-V524 +void jo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(const char *label, LabelType type = T_AUTO) { jp(std::string(label), type); }//-V524 +void jp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const char *label, LabelType type = T_AUTO) { jpe(std::string(label), type); }//-V524 +void jpe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(const char *label, LabelType type = T_AUTO) { jpo(std::string(label), type); }//-V524 +void jpo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void js(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void js(const char *label, LabelType type = T_AUTO) { js(std::string(label), type); }//-V524 +void js(const void *addr) { opJmpAbs(addr, T_NEAR, 0x78, 0x88, 0x0F); }//-V524 +void js(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void jz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jz(const char *label, LabelType type = T_AUTO) { jz(std::string(label), type); }//-V524 +void jz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void jz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void lahf() { db(0x9F); } +void lddqu(const Xmm& xmm, const Address& addr) { db(0xF2); opModM(addr, xmm, 0x0F, 0xF0); } +void ldmxcsr(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0xAE); } +void lea(const Reg& reg, const Address& addr) { if (!reg.isBit(16 | i32e)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModM(addr, reg, 0x8D); } +void lfence() { db(0x0F); db(0xAE); db(0xE8); } +void lock() { db(0xF0); } +void lzcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBD); } +void maskmovdqu(const Xmm& reg1, const Xmm& reg2) { db(0x66); opModR(reg1, reg2, 0x0F, 0xF7); } +void maskmovq(const Mmx& reg1, const Mmx& reg2) { if (!reg1.isMMX() || !reg2.isMMX()) throw Error(ERR_BAD_COMBINATION); opModR(reg1, reg2, 0x0F, 0xF7); } +void maxpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x66, isXMM_XMMorMEM); } +void maxps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x100, isXMM_XMMorMEM); } +void maxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF2, isXMM_XMMorMEM); } +void maxss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF3, isXMM_XMMorMEM); } +void mfence() { db(0x0F); db(0xAE); db(0xF0); } +void minpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x66, isXMM_XMMorMEM); } +void minps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x100, isXMM_XMMorMEM); } +void minsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF2, isXMM_XMMorMEM); } +void minss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF3, isXMM_XMMorMEM); } +void monitor() { db(0x0F); db(0x01); db(0xC8); } +void movapd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x29); } +void movapd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x66); } +void movaps(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x29); } +void movaps(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x100); } +void movbe(const Address& addr, const Reg& reg) { opModM(addr, reg, 0x0F, 0x38, 0xF1); } +void movbe(const Reg& reg, const Address& addr) { opModM(addr, reg, 0x0F, 0x38, 0xF0); } +void movd(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x7E); } +void movd(const Mmx& mmx, const Address& addr) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x6E); } +void movd(const Mmx& mmx, const Reg32& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } +void movd(const Reg32& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } +void movddup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF2, isXMM_XMMorMEM, NONE, NONE); } +void movdq2q(const Mmx& mmx, const Xmm& xmm) { db(0xF2); opModR(mmx, xmm, 0x0F, 0xD6); } +void movdqa(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x7F); } +void movdqa(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0x66); } +void movdqu(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x7F); } +void movdqu(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0xF3); } +void movhlps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x12); } +void movhpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x66); } +void movhps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x100); } +void movlhps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x16); } +void movlpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x66); } +void movlps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x100); } +void movmskpd(const Reg32e& reg, const Xmm& xmm) { db(0x66); movmskps(reg, xmm); } +void movmskps(const Reg32e& reg, const Xmm& xmm) { opModR(reg, xmm, 0x0F, 0x50); } +void movntdq(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0xE7); } +void movntdqa(const Xmm& xmm, const Address& addr) { db(0x66); opModM(addr, xmm, 0x0F, 0x38, 0x2A); } +void movnti(const Address& addr, const Reg32e& reg) { opModM(addr, reg, 0x0F, 0xC3); } +void movntpd(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0x2B); } +void movntps(const Address& addr, const Xmm& xmm) { opModM(addr, Mmx(xmm.getIdx()), 0x0F, 0x2B); } +void movntq(const Address& addr, const Mmx& mmx) { if (!mmx.isMMX()) throw Error(ERR_BAD_COMBINATION); opModM(addr, mmx, 0x0F, 0xE7); } +void movq(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, mmx.isXMM() ? 0xD6 : 0x7F); } +void movq(const Mmx& mmx, const Operand& op) { if (mmx.isXMM()) db(0xF3); opModRM(mmx, op, (mmx.getKind() == op.getKind()), op.isMEM(), 0x0F, mmx.isXMM() ? 0x7E : 0x6F); } +void movq2dq(const Xmm& xmm, const Mmx& mmx) { db(0xF3); opModR(xmm, mmx, 0x0F, 0xD6); } +void movsb() { db(0xA4); } +void movsd() { db(0xA5); } +void movsd(const Address& addr, const Xmm& xmm) { db(0xF2); opModM(addr, xmm, 0x0F, 0x11); } +void movsd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF2); } +void movshdup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x16, 0xF3, isXMM_XMMorMEM, NONE, NONE); } +void movsldup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF3, isXMM_XMMorMEM, NONE, NONE); } +void movss(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x11); } +void movss(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF3); } +void movsw() { db(0x66); db(0xA5); } +void movsx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xBE); } +void movupd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x11); } +void movupd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x66); } +void movups(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x11); } +void movups(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x100); } +void movzx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xB6); } +void mpsadbw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x42, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void mul(const Operand& op) { opR_ModM(op, 0, 4, 0xF6); } +void mulpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x66, isXMM_XMMorMEM); } +void mulps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x100, isXMM_XMMorMEM); } +void mulsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF2, isXMM_XMMorMEM); } +void mulss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF3, isXMM_XMMorMEM); } +void mulx(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf6, true); } +void mwait() { db(0x0F); db(0x01); db(0xC9); } +void neg(const Operand& op) { opR_ModM(op, 0, 3, 0xF6); } +void not_(const Operand& op) { opR_ModM(op, 0, 2, 0xF6); } +void or_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x08, 1); } +void or_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x08); } +void orpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x66, isXMM_XMMorMEM); } +void orps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x100, isXMM_XMMorMEM); } +void pabsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1C, 0x66, NONE, 0x38); } +void pabsd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1E, 0x66, NONE, 0x38); } +void pabsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1D, 0x66, NONE, 0x38); } +void packssdw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6B); } +void packsswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x63); } +void packusdw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void packuswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x67); } +void paddb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFC); } +void paddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFE); } +void paddq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD4); } +void paddsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEC); } +void paddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xED); } +void paddusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDC); } +void paddusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDD); } +void paddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFD); } +void palignr(const Mmx& mmx, const Operand& op, int imm) { opMMX(mmx, op, 0x0f, 0x66, static_cast(imm), 0x3a); } +void pand(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDB); } +void pandn(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDF); } +void pause() { db(0xF3); db(0x90); } +void pavgb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE0); } +void pavgw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE3); } +void pblendvb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x10, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pblendw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0E, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void pclmulhqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x11); } +void pclmulhqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x01); } +void pclmullqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x10); } +void pclmullqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x00); } +void pclmulqdq(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x44, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void pcmpeqb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x74); } +void pcmpeqd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x76); } +void pcmpeqq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x29, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pcmpeqw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x75); } +void pcmpestri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x61, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpestrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x60, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpgtb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x64); } +void pcmpgtd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x66); } +void pcmpgtq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x37, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pcmpgtw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x65); } +void pcmpistri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x63, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpistrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x62, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pdep(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf5, true); } +void pext(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F3 | T_0F38, 0xf5, true); } +void pextrb(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x14, imm); } +void pextrd(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x16, imm); } +void pextrw(const Operand& op, const Mmx& xmm, uint8 imm) { opExt(op, xmm, 0x15, imm, true); } +void phaddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x02, 0x66, NONE, 0x38); } +void phaddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x03, 0x66, NONE, 0x38); } +void phaddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x01, 0x66, NONE, 0x38); } +void phminposuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void phsubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x06, 0x66, NONE, 0x38); } +void phsubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x07, 0x66, NONE, 0x38); } +void phsubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x05, 0x66, NONE, 0x38); } +void pinsrb(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x20, 0x66, isXMM_REG32orMEM, imm, 0x3A); } +void pinsrd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x22, 0x66, isXMM_REG32orMEM, imm, 0x3A); } +void pinsrw(const Mmx& mmx, const Operand& op, int imm) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(mmx, op, 0xC4, mmx.isXMM() ? 0x66 : NONE, 0, imm); } +void pmaddubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x04, 0x66, NONE, 0x38); } +void pmaddwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF5); } +void pmaxsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3C, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3D, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEE); } +void pmaxub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDE); } +void pmaxud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3F, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3E, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x38, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x39, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEA); } +void pminub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDA); } +void pminud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3A, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovmskb(const Reg32e& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(reg, mmx, 0x0F, 0xD7); } +void pmovsxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x22, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x20, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x25, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x23, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x24, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x31, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x32, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x30, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x35, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x33, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x34, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmuldq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x28, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmulhrsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0B, 0x66, NONE, 0x38); } +void pmulhuw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE4); } +void pmulhw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE5); } +void pmulld(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmullw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD5); } +void pmuludq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF4); } +void popcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xB8); } +void popf() { db(0x9D); } +void por(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEB); } +void prefetchnta(const Address& addr) { opModM(addr, Reg32(0), 0x0F, 0x18); } +void prefetcht0(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x18); } +void prefetcht1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x18); } +void prefetcht2(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0x18); } +void prefetchw(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x0D); } +void prefetchwt1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x0D); } +void psadbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF6); } +void pshufb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x00, 0x66, NONE, 0x38); } +void pshufd(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x66, imm8); } +void pshufhw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF3, imm8); } +void pshuflw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF2, imm8); } +void pshufw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x00, imm8); } +void psignb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x08, 0x66, NONE, 0x38); } +void psignd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0A, 0x66, NONE, 0x38); } +void psignw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x09, 0x66, NONE, 0x38); } +void pslld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF2); } +void pslld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 6); } +void pslldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 7); } +void psllq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF3); } +void psllq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 6); } +void psllw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF1); } +void psllw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 6); } +void psrad(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE2); } +void psrad(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 4); } +void psraw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE1); } +void psraw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 4); } +void psrld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD2); } +void psrld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 2); } +void psrldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 3); } +void psrlq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD3); } +void psrlq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 2); } +void psrlw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD1); } +void psrlw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 2); } +void psubb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF8); } +void psubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFA); } +void psubq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFB); } +void psubsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE8); } +void psubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE9); } +void psubusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD8); } +void psubusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD9); } +void psubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF9); } +void ptest(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x17, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void punpckhbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x68); } +void punpckhdq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6A); } +void punpckhqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6D, 0x66, isXMM_XMMorMEM); } +void punpckhwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x69); } +void punpcklbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x60); } +void punpckldq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x62); } +void punpcklqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6C, 0x66, isXMM_XMMorMEM); } +void punpcklwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x61); } +void pushf() { db(0x9C); } +void pxor(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEF); } +void rcl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2); } +void rcl(const Operand& op, int imm) { opShift(op, imm, 2); } +void rcpps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0x100, isXMM_XMMorMEM); } +void rcpss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0xF3, isXMM_XMMorMEM); } +void rcr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3); } +void rcr(const Operand& op, int imm) { opShift(op, imm, 3); } +void rdmsr() { db(0x0F); db(0x32); } +void rdpmc() { db(0x0F); db(0x33); } +void rdrand(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(6, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } +void rdseed(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(7, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } +void rdtsc() { db(0x0F); db(0x31); } +void rdtscp() { db(0x0F); db(0x01); db(0xF9); } +void rep() { db(0xF3); } +void ret(int imm = 0) { if (imm) { db(0xC2); dw(imm); } else { db(0xC3); } } +void rol(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 0); } +void rol(const Operand& op, int imm) { opShift(op, imm, 0); } +void ror(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 1); } +void ror(const Operand& op, int imm) { opShift(op, imm, 1); } +void rorx(const Reg32e& r, const Operand& op, uint8 imm) { opGpr(r, op, Reg32e(0, r.getBit()), T_0F3A | T_F2, 0xF0, false, imm); } +void roundpd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x09, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void roundps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x08, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void roundsd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0B, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void roundss(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0A, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } +void rsqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0x100, isXMM_XMMorMEM); } +void rsqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0xF3, isXMM_XMMorMEM); } +void sahf() { db(0x9E); } +void sal(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } +void sal(const Operand& op, int imm) { opShift(op, imm, 4); } +void sar(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 7); } +void sar(const Operand& op, int imm) { opShift(op, imm, 7); } +void sarx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F3 | T_0F38, 0xf7, false); } +void sbb(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x18, 3); } +void sbb(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x18); } +void scasb() { db(0xAE); } +void scasd() { db(0xAF); } +void scasw() { db(0x66); db(0xAF); } +void seta(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 +void setae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void setbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 +void setc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void sete(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 +void setg(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 +void setge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 +void setl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 +void setle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 +void setna(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 +void setnae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void setnb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setnbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 +void setnc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setne(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 +void setng(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 +void setnge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 +void setnl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 +void setnle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 +void setno(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 1); }//-V524 +void setnp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 +void setns(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 9); }//-V524 +void setnz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 +void seto(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 0); }//-V524 +void setp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 +void setpe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 +void setpo(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 +void sets(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 8); }//-V524 +void setz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 +void sfence() { db(0x0F); db(0xAE); db(0xF8); } +void sha1msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC9, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCA, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1nexte(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC8, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1rnds4(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, imm, 0x3A); } +void sha256msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha256msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCD, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha256rnds2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCB, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void shl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } +void shl(const Operand& op, int imm) { opShift(op, imm, 4); } +void shld(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xA4, &_cl); } +void shld(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xA4); } +void shlx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_66 | T_0F38, 0xf7, false); } +void shr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 5); } +void shr(const Operand& op, int imm) { opShift(op, imm, 5); } +void shrd(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xAC, &_cl); } +void shrd(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xAC); } +void shrx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F2 | T_0F38, 0xf7, false); } +void shufpd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x66, isXMM_XMMorMEM, imm8); } +void shufps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x100, isXMM_XMMorMEM, imm8); } +void sqrtpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x66, isXMM_XMMorMEM); } +void sqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x100, isXMM_XMMorMEM); } +void sqrtsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF2, isXMM_XMMorMEM); } +void sqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF3, isXMM_XMMorMEM); } +void stac() { db(0x0F); db(0x01); db(0xCB); } +void stc() { db(0xF9); } +void std() { db(0xFD); } +void sti() { db(0xFB); } +void stmxcsr(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0xAE); } +void stosb() { db(0xAA); } +void stosd() { db(0xAB); } +void stosw() { db(0x66); db(0xAB); } +void sub(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x28, 5); } +void sub(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x28); } +void subpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x66, isXMM_XMMorMEM); } +void subps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x100, isXMM_XMMorMEM); } +void subsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF2, isXMM_XMMorMEM); } +void subss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF3, isXMM_XMMorMEM); } +void tzcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBC); } +void ucomisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x66, isXMM_XMMorMEM); } +void ucomiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x100, isXMM_XMMorMEM); } +void ud2() { db(0x0F); db(0x0B); } +void unpckhpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM); } +void unpckhps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x100, isXMM_XMMorMEM); } +void unpcklpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM); } +void unpcklps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x100, isXMM_XMMorMEM); } +void vaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x58); } +void vaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x58); } +void vaddsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x58); } +void vaddss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x58); } +void vaddsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0xD0); } +void vaddsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0xD0); } +void vaesdec(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDE); } +void vaesdeclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDF); } +void vaesenc(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDC); } +void vaesenclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDD); } +void vaesimc(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_W0, 0xDB); } +void vaeskeygenassist(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0xDF, imm); } +void vandnpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x55); } +void vandnps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x55); } +void vandpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x54); } +void vandps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x54); } +void vblendpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0D, imm); } +void vblendps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0C, imm); } +void vblendvpd(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4B, x4.getIdx() << 4); } +void vblendvps(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4A, x4.getIdx() << 4); } +void vbroadcastf128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x1A); } +void vbroadcasti128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x5A); } +void vbroadcastsd(const Ymm& y, const Operand& op) { if (!op.isMEM() && !(y.isYMM() && op.isXMM()) && !(y.isZMM() && op.isXMM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(y, op, T_0F38 | T_66 | T_W0 | T_YMM | T_EVEX | T_EW1 | T_N8, 0x19); } +void vbroadcastss(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x18); } +void vcmpeq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 16); } +void vcmpeq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 16); } +void vcmpeq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 16); } +void vcmpeq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 16); } +void vcmpeq_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 8); } +void vcmpeq_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 8); } +void vcmpeq_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 8); } +void vcmpeq_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 8); } +void vcmpeq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 24); } +void vcmpeq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 24); } +void vcmpeq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 24); } +void vcmpeq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 24); } +void vcmpeqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 0); } +void vcmpeqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 0); } +void vcmpeqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 0); } +void vcmpeqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 0); } +void vcmpfalse_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 27); } +void vcmpfalse_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 27); } +void vcmpfalse_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 27); } +void vcmpfalse_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 27); } +void vcmpfalsepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 11); } +void vcmpfalseps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 11); } +void vcmpfalsesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 11); } +void vcmpfalsess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 11); } +void vcmpge_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 29); } +void vcmpge_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 29); } +void vcmpge_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 29); } +void vcmpge_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 29); } +void vcmpgepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 13); } +void vcmpgeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 13); } +void vcmpgesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 13); } +void vcmpgess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 13); } +void vcmpgt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 30); } +void vcmpgt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 30); } +void vcmpgt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 30); } +void vcmpgt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 30); } +void vcmpgtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 14); } +void vcmpgtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 14); } +void vcmpgtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 14); } +void vcmpgtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 14); } +void vcmple_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 18); } +void vcmple_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 18); } +void vcmple_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 18); } +void vcmple_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 18); } +void vcmplepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 2); } +void vcmpleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 2); } +void vcmplesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 2); } +void vcmpless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 2); } +void vcmplt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 17); } +void vcmplt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 17); } +void vcmplt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 17); } +void vcmplt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 17); } +void vcmpltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 1); } +void vcmpltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 1); } +void vcmpltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 1); } +void vcmpltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 1); } +void vcmpneq_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 12); } +void vcmpneq_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 12); } +void vcmpneq_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 12); } +void vcmpneq_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 12); } +void vcmpneq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 28); } +void vcmpneq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 28); } +void vcmpneq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 28); } +void vcmpneq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 28); } +void vcmpneq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 20); } +void vcmpneq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 20); } +void vcmpneq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 20); } +void vcmpneq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 20); } +void vcmpneqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 4); } +void vcmpneqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 4); } +void vcmpneqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 4); } +void vcmpneqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 4); } +void vcmpnge_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 25); } +void vcmpnge_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 25); } +void vcmpnge_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 25); } +void vcmpnge_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 25); } +void vcmpngepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 9); } +void vcmpngeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 9); } +void vcmpngesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 9); } +void vcmpngess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 9); } +void vcmpngt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 26); } +void vcmpngt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 26); } +void vcmpngt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 26); } +void vcmpngt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 26); } +void vcmpngtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 10); } +void vcmpngtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 10); } +void vcmpngtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 10); } +void vcmpngtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 10); } +void vcmpnle_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 22); } +void vcmpnle_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 22); } +void vcmpnle_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 22); } +void vcmpnle_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 22); } +void vcmpnlepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 6); } +void vcmpnleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 6); } +void vcmpnlesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 6); } +void vcmpnless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 6); } +void vcmpnlt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 21); } +void vcmpnlt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 21); } +void vcmpnlt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 21); } +void vcmpnlt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 21); } +void vcmpnltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 5); } +void vcmpnltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 5); } +void vcmpnltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 5); } +void vcmpnltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 5); } +void vcmpord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 23); } +void vcmpord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 23); } +void vcmpord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 23); } +void vcmpord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 23); } +void vcmpordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 7); } +void vcmpordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 7); } +void vcmpordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 7); } +void vcmpordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 7); } +void vcmppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xC2, imm); } +void vcmpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_YMM, 0xC2, imm); } +void vcmpsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F2 | T_0F, 0xC2, imm); } +void vcmpss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0xC2, imm); } +void vcmptrue_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 31); } +void vcmptrue_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 31); } +void vcmptrue_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 31); } +void vcmptrue_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 31); } +void vcmptruepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 15); } +void vcmptrueps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 15); } +void vcmptruesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 15); } +void vcmptruess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 15); } +void vcmpunord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 19); } +void vcmpunord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 19); } +void vcmpunord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 19); } +void vcmpunord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 19); } +void vcmpunordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 3); } +void vcmpunordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 3); } +void vcmpunordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 3); } +void vcmpunordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 3); } +void vcomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2F); } +void vcomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2F); } +void vcvtdq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_F3 | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0xE6); } +void vcvtdq2ps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } +void vcvtpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_F2 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvtpd2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_66 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5A); } +void vcvtph2ps(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F38 | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x13); } +void vcvtps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } +void vcvtps2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x5A); } +void vcvtps2ph(const Operand& op, const Xmm& x, uint8 imm) { checkCvt1(x, op); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x1D, imm); } +void vcvtsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_ER_X, 0x2D); } +void vcvtsd2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x5A); } +void vcvtsi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F2 | T_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtsi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F3 | T_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtss2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x5A); } +void vcvtss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_ER_X | T_N8, 0x2D); } +void vcvttpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_66 | T_0F | T_YMM | T_EVEX |T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvttps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX | T_SAE_Z | T_B32, 0x5B); } +void vcvttsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_SAE_X, 0x2C); } +void vcvttss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_SAE_X | T_N8, 0x2C); } +void vdivpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5E); } +void vdivps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5E); } +void vdivsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5E); } +void vdivss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5E); } +void vdppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x41, imm); } +void vdpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x40, imm); } +void vextractf128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x19, imm); } +void vextracti128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x39, imm); } +void vextractps(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_N4, 0x17, imm); } +void vfmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x98); } +void vfmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x98); } +void vfmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x99); } +void vfmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x99); } +void vfmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA8); } +void vfmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA8); } +void vfmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xA9); } +void vfmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xA9); } +void vfmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB8); } +void vfmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB8); } +void vfmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xB9); } +void vfmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xB9); } +void vfmaddsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x96); } +void vfmaddsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x96); } +void vfmaddsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA6); } +void vfmaddsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA6); } +void vfmaddsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB6); } +void vfmaddsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB6); } +void vfmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9A); } +void vfmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9A); } +void vfmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9B); } +void vfmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9B); } +void vfmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAA); } +void vfmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAA); } +void vfmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAB); } +void vfmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAB); } +void vfmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBA); } +void vfmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBA); } +void vfmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBB); } +void vfmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBB); } +void vfmsubadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x97); } +void vfmsubadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x97); } +void vfmsubadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA7); } +void vfmsubadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA7); } +void vfmsubadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB7); } +void vfmsubadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB7); } +void vfnmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9C); } +void vfnmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9C); } +void vfnmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9D); } +void vfnmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9D); } +void vfnmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAC); } +void vfnmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAC); } +void vfnmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAD); } +void vfnmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAD); } +void vfnmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBC); } +void vfnmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBC); } +void vfnmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBD); } +void vfnmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBD); } +void vfnmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9E); } +void vfnmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9E); } +void vfnmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9F); } +void vfnmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9F); } +void vfnmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAE); } +void vfnmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAE); } +void vfnmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAF); } +void vfnmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAF); } +void vfnmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBE); } +void vfnmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBE); } +void vfnmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBF); } +void vfnmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBF); } +void vgatherdpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x92, 0); } +void vgatherdps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x92, 1); } +void vgatherqpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x93, 1); } +void vgatherqps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x93, 2); } +void vgf2p8affineinvqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCF, imm); } +void vgf2p8affineqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCE, imm); } +void vgf2p8mulb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_SAE_Z, 0xCF); } +void vhaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7C); } +void vhaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7C); } +void vhsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7D); } +void vhsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7D); } +void vinsertf128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x18, imm); } +void vinserti128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x38, imm); } +void vinsertps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_W0 | T_EW0 | T_EVEX, 0x21, imm); } +void vlddqu(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, cvtIdx0(x), addr, T_0F | T_F2 | T_W0 | T_YMM, 0xF0); } +void vldmxcsr(const Address& addr) { opAVX_X_X_XM(xm2, xm0, addr, T_0F, 0xAE); } +void vmaskmovdqu(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_66, 0xF7); } +void vmaskmovpd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2F); } +void vmaskmovpd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2D); } +void vmaskmovps(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2E); } +void vmaskmovps(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2C); } +void vmaxpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5F); } +void vmaxps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5F); } +void vmaxsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5F); } +void vmaxss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5F); } +void vminpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5D); } +void vminps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5D); } +void vminsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5D); } +void vminss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5D); } +void vmovapd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x29); } +void vmovapd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x28); } +void vmovaps(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x29); } +void vmovaps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x28); } +void vmovd(const Operand& op, const Xmm& x) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x7E); } +void vmovd(const Xmm& x, const Operand& op) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x6E); } +void vmovddup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_DUP | T_F2 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_X | T_ER_Y | T_ER_Z, 0x12); } +void vmovdqa(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_YMM, 0x7F); } +void vmovdqa(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_YMM, 0x6F); } +void vmovdqu(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_F3 | T_0F | T_YMM, 0x7F); } +void vmovdqu(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM, 0x6F); } +void vmovhlps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x12); } +void vmovhpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x17); } +void vmovhpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x16); } +void vmovhps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x17); } +void vmovhps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x16); } +void vmovlhps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x16); } +void vmovlpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x13); } +void vmovlpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x12); } +void vmovlps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x13); } +void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x12); } +void vmovmskpd(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50); } +void vmovmskps(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50); } +void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); } +void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); } +void vmovntpd(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW1, 0x2B); } +void vmovntps(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_YMM | T_EVEX | T_EW0, 0x2B); } +void vmovq(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, x.getIdx() < 16 ? 0xD6 : 0x7E); } +void vmovq(const Xmm& x, const Address& addr) { int type, code; if (x.getIdx() < 16) { type = T_0F | T_F3; code = 0x7E; } else { type = T_0F | T_66 | T_EVEX | T_EW1 | T_N8; code = 0x6E; } opAVX_X_X_XM(x, xm0, addr, type, code); } +void vmovq(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_F3 | T_EVEX | T_EW1 | T_N8, 0x7E); } +void vmovsd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_M_K, 0x11); } +void vmovsd(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } +void vmovsd(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } +void vmovshdup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x16); } +void vmovsldup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x12); } +void vmovss(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_M_K, 0x11); } +void vmovss(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } +void vmovss(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } +void vmovupd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x11); } +void vmovupd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x10); } +void vmovups(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x11); } +void vmovups(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x10); } +void vmpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x42, imm); } +void vmulpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x59); } +void vmulps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x59); } +void vmulsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x59); } +void vmulss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x59); } +void vorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x56); } +void vorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x56); } +void vpabsb(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1C); } +void vpabsd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x1E); } +void vpabsw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1D); } +void vpackssdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6B); } +void vpacksswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x63); } +void vpackusdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x2B); } +void vpackuswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x67); } +void vpaddb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFC); } +void vpaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFE); } +void vpaddq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xD4); } +void vpaddsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEC); } +void vpaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xED); } +void vpaddusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDC); } +void vpaddusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDD); } +void vpaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFD); } +void vpalignr(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_YMM | T_EVEX, 0x0F, imm); } +void vpand(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDB); } +void vpandn(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDF); } +void vpavgb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE0); } +void vpavgw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE3); } +void vpblendd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x02, imm); } +void vpblendvb(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4C, x4.getIdx() << 4); } +void vpblendw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0E, imm); } +void vpbroadcastb(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x78); } +void vpbroadcastd(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x58); } +void vpbroadcastq(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX, 0x59); } +void vpbroadcastw(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x79); } +void vpclmulqdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM | T_EVEX, 0x44, imm); } +void vpcmpeqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x74); } +void vpcmpeqd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x76); } +void vpcmpeqq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x29); } +void vpcmpeqw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x75); } +void vpcmpestri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x61, imm); } +void vpcmpestrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x60, imm); } +void vpcmpgtb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x64); } +void vpcmpgtd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x66); } +void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x37); } +void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x65); } +void vpcmpistri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x63, imm); } +void vpcmpistrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x62, imm); } +void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); } +void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); } +void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x36); } +void vpermilpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x0D); } +void vpermilpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_EVEX | T_B64, 0x05, imm); } +void vpermilps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x0C); } +void vpermilps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_EVEX | T_B32, 0x04, imm); } +void vpermpd(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x01, imm); } +void vpermpd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x16); } +void vpermps(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x16); } +void vpermq(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x00, imm); } +void vpermq(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x36); } +void vpextrb(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(8|16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x14, imm); } +void vpextrd(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x16, imm); } +void vpextrq(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(64) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x16, imm); } +void vpextrw(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); if (op.isREG() && x.getIdx() < 16) { opAVX_X_X_XM(Xmm(op.getIdx()), xm0, x, T_0F | T_66, 0xC5, imm); } else { opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N2, 0x15, imm); } } +void vpgatherdd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x90, 1); } +void vpgatherdq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x90, 0); } +void vpgatherqd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x91, 2); } +void vpgatherqq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x91, 1); } +void vphaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x02); } +void vphaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x03); } +void vphaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x01); } +void vphminposuw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38, 0x41); } +void vphsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x06); } +void vphsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x07); } +void vphsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x05); } +void vpinsrb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x20, imm); } +void vpinsrd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x22, imm); } +void vpinsrq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(64) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x22, imm); } +void vpinsrw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F | T_66 | T_EVEX | T_N2, 0xC4, imm); } +void vpmaddubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x04); } +void vpmaddwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF5); } +void vpmaskmovd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8E); } +void vpmaskmovd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8C); } +void vpmaskmovq(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8E); } +void vpmaskmovq(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8C); } +void vpmaxsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3C); } +void vpmaxsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3D); } +void vpmaxsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEE); } +void vpmaxub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDE); } +void vpmaxud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3F); } +void vpmaxuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3E); } +void vpminsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x38); } +void vpminsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x39); } +void vpminsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEA); } +void vpminub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDA); } +void vpminud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3B); } +void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3A); } +void vpmovmskb(const Reg32e& r, const Xmm& x) { if (!x.is(Operand::XMM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(x.isYMM() ? Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7); } +void vpmovsxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x21); } +void vpmovsxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x22); } +void vpmovsxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x20); } +void vpmovsxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x25); } +void vpmovsxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x23); } +void vpmovsxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x24); } +void vpmovzxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x31); } +void vpmovzxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x32); } +void vpmovzxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x30); } +void vpmovzxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x35); } +void vpmovzxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x33); } +void vpmovzxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x34); } +void vpmuldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x28); } +void vpmulhrsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x0B); } +void vpmulhuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE4); } +void vpmulhw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE5); } +void vpmulld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x40); } +void vpmullw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD5); } +void vpmuludq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xF4); } +void vpor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEB); } +void vpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF6); } +void vpshufb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x00); } +void vpshufd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x70, imm); } +void vpshufhw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM | T_EVEX, 0x70, imm); } +void vpshuflw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F2 | T_0F | T_YMM | T_EVEX, 0x70, imm); } +void vpsignb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x08); } +void vpsignd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x0A); } +void vpsignw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x09); } +void vpslld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpslld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xF2); } +void vpslldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 7), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xF3); } +void vpsllvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x47); } +void vpsllvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x47); } +void vpsllw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsllw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xF1); } +void vpsrad(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpsrad(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xE2); } +void vpsravd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x46); } +void vpsraw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsraw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xE1); } +void vpsrld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpsrld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xD2); } +void vpsrldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 3), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xD3); } +void vpsrlvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x45); } +void vpsrlvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x45); } +void vpsrlw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsrlw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xD1); } +void vpsubb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF8); } +void vpsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFA); } +void vpsubq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xFB); } +void vpsubsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE8); } +void vpsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE9); } +void vpsubusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD8); } +void vpsubusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD9); } +void vpsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF9); } +void vptest(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x17); } +void vpunpckhbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x68); } +void vpunpckhdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6A); } +void vpunpckhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6D); } +void vpunpckhwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x69); } +void vpunpcklbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x60); } +void vpunpckldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x62); } +void vpunpcklqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6C); } +void vpunpcklwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x61); } +void vpxor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEF); } +void vrcpps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x53); } +void vrcpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x53); } +void vroundpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x09, imm); } +void vroundps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x08, imm); } +void vroundsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0B, imm); } +void vroundss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0A, imm); } +void vrsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x52); } +void vrsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x52); } +void vshufpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xC6, imm); } +void vshufps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xC6, imm); } +void vsqrtpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x51); } +void vsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x51); } +void vsqrtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x51); } +void vsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_ER_X, 0x51); } +void vstmxcsr(const Address& addr) { opAVX_X_X_XM(xm3, xm0, addr, T_0F, 0xAE); } +void vsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5C); } +void vsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5C); } +void vsubsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5C); } +void vsubss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5C); } +void vtestpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0F); } +void vtestps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0E); } +void vucomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2E); } +void vucomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2E); } +void vunpckhpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x15); } +void vunpckhps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x15); } +void vunpcklpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x14); } +void vunpcklps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x14); } +void vxorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x57); } +void vxorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x57); } +void vzeroall() { db(0xC5); db(0xFC); db(0x77); } +void vzeroupper() { db(0xC5); db(0xF8); db(0x77); } +void wait() { db(0x9B); } +void wbinvd() { db(0x0F); db(0x09); } +void wrmsr() { db(0x0F); db(0x30); } +void xadd(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xC0 | (reg.isBit(8) ? 0 : 1)); } +void xgetbv() { db(0x0F); db(0x01); db(0xD0); } +void xlatb() { db(0xD7); } +void xor_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x30, 6); } +void xor_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x30); } +void xorpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x66, isXMM_XMMorMEM); } +void xorps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x100, isXMM_XMMorMEM); } +#ifdef XBYAK_ENABLE_OMITTED_OPERAND +void vblendpd(const Xmm& x, const Operand& op, uint8 imm) { vblendpd(x, x, op, imm); } +void vblendps(const Xmm& x, const Operand& op, uint8 imm) { vblendps(x, x, op, imm); } +void vblendvpd(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvpd(x1, x1, op, x4); } +void vblendvps(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvps(x1, x1, op, x4); } +void vcmpeq_ospd(const Xmm& x, const Operand& op) { vcmpeq_ospd(x, x, op); } +void vcmpeq_osps(const Xmm& x, const Operand& op) { vcmpeq_osps(x, x, op); } +void vcmpeq_ossd(const Xmm& x, const Operand& op) { vcmpeq_ossd(x, x, op); } +void vcmpeq_osss(const Xmm& x, const Operand& op) { vcmpeq_osss(x, x, op); } +void vcmpeq_uqpd(const Xmm& x, const Operand& op) { vcmpeq_uqpd(x, x, op); } +void vcmpeq_uqps(const Xmm& x, const Operand& op) { vcmpeq_uqps(x, x, op); } +void vcmpeq_uqsd(const Xmm& x, const Operand& op) { vcmpeq_uqsd(x, x, op); } +void vcmpeq_uqss(const Xmm& x, const Operand& op) { vcmpeq_uqss(x, x, op); } +void vcmpeq_uspd(const Xmm& x, const Operand& op) { vcmpeq_uspd(x, x, op); } +void vcmpeq_usps(const Xmm& x, const Operand& op) { vcmpeq_usps(x, x, op); } +void vcmpeq_ussd(const Xmm& x, const Operand& op) { vcmpeq_ussd(x, x, op); } +void vcmpeq_usss(const Xmm& x, const Operand& op) { vcmpeq_usss(x, x, op); } +void vcmpeqpd(const Xmm& x, const Operand& op) { vcmpeqpd(x, x, op); } +void vcmpeqps(const Xmm& x, const Operand& op) { vcmpeqps(x, x, op); } +void vcmpeqsd(const Xmm& x, const Operand& op) { vcmpeqsd(x, x, op); } +void vcmpeqss(const Xmm& x, const Operand& op) { vcmpeqss(x, x, op); } +void vcmpfalse_ospd(const Xmm& x, const Operand& op) { vcmpfalse_ospd(x, x, op); } +void vcmpfalse_osps(const Xmm& x, const Operand& op) { vcmpfalse_osps(x, x, op); } +void vcmpfalse_ossd(const Xmm& x, const Operand& op) { vcmpfalse_ossd(x, x, op); } +void vcmpfalse_osss(const Xmm& x, const Operand& op) { vcmpfalse_osss(x, x, op); } +void vcmpfalsepd(const Xmm& x, const Operand& op) { vcmpfalsepd(x, x, op); } +void vcmpfalseps(const Xmm& x, const Operand& op) { vcmpfalseps(x, x, op); } +void vcmpfalsesd(const Xmm& x, const Operand& op) { vcmpfalsesd(x, x, op); } +void vcmpfalsess(const Xmm& x, const Operand& op) { vcmpfalsess(x, x, op); } +void vcmpge_oqpd(const Xmm& x, const Operand& op) { vcmpge_oqpd(x, x, op); } +void vcmpge_oqps(const Xmm& x, const Operand& op) { vcmpge_oqps(x, x, op); } +void vcmpge_oqsd(const Xmm& x, const Operand& op) { vcmpge_oqsd(x, x, op); } +void vcmpge_oqss(const Xmm& x, const Operand& op) { vcmpge_oqss(x, x, op); } +void vcmpgepd(const Xmm& x, const Operand& op) { vcmpgepd(x, x, op); } +void vcmpgeps(const Xmm& x, const Operand& op) { vcmpgeps(x, x, op); } +void vcmpgesd(const Xmm& x, const Operand& op) { vcmpgesd(x, x, op); } +void vcmpgess(const Xmm& x, const Operand& op) { vcmpgess(x, x, op); } +void vcmpgt_oqpd(const Xmm& x, const Operand& op) { vcmpgt_oqpd(x, x, op); } +void vcmpgt_oqps(const Xmm& x, const Operand& op) { vcmpgt_oqps(x, x, op); } +void vcmpgt_oqsd(const Xmm& x, const Operand& op) { vcmpgt_oqsd(x, x, op); } +void vcmpgt_oqss(const Xmm& x, const Operand& op) { vcmpgt_oqss(x, x, op); } +void vcmpgtpd(const Xmm& x, const Operand& op) { vcmpgtpd(x, x, op); } +void vcmpgtps(const Xmm& x, const Operand& op) { vcmpgtps(x, x, op); } +void vcmpgtsd(const Xmm& x, const Operand& op) { vcmpgtsd(x, x, op); } +void vcmpgtss(const Xmm& x, const Operand& op) { vcmpgtss(x, x, op); } +void vcmple_oqpd(const Xmm& x, const Operand& op) { vcmple_oqpd(x, x, op); } +void vcmple_oqps(const Xmm& x, const Operand& op) { vcmple_oqps(x, x, op); } +void vcmple_oqsd(const Xmm& x, const Operand& op) { vcmple_oqsd(x, x, op); } +void vcmple_oqss(const Xmm& x, const Operand& op) { vcmple_oqss(x, x, op); } +void vcmplepd(const Xmm& x, const Operand& op) { vcmplepd(x, x, op); } +void vcmpleps(const Xmm& x, const Operand& op) { vcmpleps(x, x, op); } +void vcmplesd(const Xmm& x, const Operand& op) { vcmplesd(x, x, op); } +void vcmpless(const Xmm& x, const Operand& op) { vcmpless(x, x, op); } +void vcmplt_oqpd(const Xmm& x, const Operand& op) { vcmplt_oqpd(x, x, op); } +void vcmplt_oqps(const Xmm& x, const Operand& op) { vcmplt_oqps(x, x, op); } +void vcmplt_oqsd(const Xmm& x, const Operand& op) { vcmplt_oqsd(x, x, op); } +void vcmplt_oqss(const Xmm& x, const Operand& op) { vcmplt_oqss(x, x, op); } +void vcmpltpd(const Xmm& x, const Operand& op) { vcmpltpd(x, x, op); } +void vcmpltps(const Xmm& x, const Operand& op) { vcmpltps(x, x, op); } +void vcmpltsd(const Xmm& x, const Operand& op) { vcmpltsd(x, x, op); } +void vcmpltss(const Xmm& x, const Operand& op) { vcmpltss(x, x, op); } +void vcmpneq_oqpd(const Xmm& x, const Operand& op) { vcmpneq_oqpd(x, x, op); } +void vcmpneq_oqps(const Xmm& x, const Operand& op) { vcmpneq_oqps(x, x, op); } +void vcmpneq_oqsd(const Xmm& x, const Operand& op) { vcmpneq_oqsd(x, x, op); } +void vcmpneq_oqss(const Xmm& x, const Operand& op) { vcmpneq_oqss(x, x, op); } +void vcmpneq_ospd(const Xmm& x, const Operand& op) { vcmpneq_ospd(x, x, op); } +void vcmpneq_osps(const Xmm& x, const Operand& op) { vcmpneq_osps(x, x, op); } +void vcmpneq_ossd(const Xmm& x, const Operand& op) { vcmpneq_ossd(x, x, op); } +void vcmpneq_osss(const Xmm& x, const Operand& op) { vcmpneq_osss(x, x, op); } +void vcmpneq_uspd(const Xmm& x, const Operand& op) { vcmpneq_uspd(x, x, op); } +void vcmpneq_usps(const Xmm& x, const Operand& op) { vcmpneq_usps(x, x, op); } +void vcmpneq_ussd(const Xmm& x, const Operand& op) { vcmpneq_ussd(x, x, op); } +void vcmpneq_usss(const Xmm& x, const Operand& op) { vcmpneq_usss(x, x, op); } +void vcmpneqpd(const Xmm& x, const Operand& op) { vcmpneqpd(x, x, op); } +void vcmpneqps(const Xmm& x, const Operand& op) { vcmpneqps(x, x, op); } +void vcmpneqsd(const Xmm& x, const Operand& op) { vcmpneqsd(x, x, op); } +void vcmpneqss(const Xmm& x, const Operand& op) { vcmpneqss(x, x, op); } +void vcmpnge_uqpd(const Xmm& x, const Operand& op) { vcmpnge_uqpd(x, x, op); } +void vcmpnge_uqps(const Xmm& x, const Operand& op) { vcmpnge_uqps(x, x, op); } +void vcmpnge_uqsd(const Xmm& x, const Operand& op) { vcmpnge_uqsd(x, x, op); } +void vcmpnge_uqss(const Xmm& x, const Operand& op) { vcmpnge_uqss(x, x, op); } +void vcmpngepd(const Xmm& x, const Operand& op) { vcmpngepd(x, x, op); } +void vcmpngeps(const Xmm& x, const Operand& op) { vcmpngeps(x, x, op); } +void vcmpngesd(const Xmm& x, const Operand& op) { vcmpngesd(x, x, op); } +void vcmpngess(const Xmm& x, const Operand& op) { vcmpngess(x, x, op); } +void vcmpngt_uqpd(const Xmm& x, const Operand& op) { vcmpngt_uqpd(x, x, op); } +void vcmpngt_uqps(const Xmm& x, const Operand& op) { vcmpngt_uqps(x, x, op); } +void vcmpngt_uqsd(const Xmm& x, const Operand& op) { vcmpngt_uqsd(x, x, op); } +void vcmpngt_uqss(const Xmm& x, const Operand& op) { vcmpngt_uqss(x, x, op); } +void vcmpngtpd(const Xmm& x, const Operand& op) { vcmpngtpd(x, x, op); } +void vcmpngtps(const Xmm& x, const Operand& op) { vcmpngtps(x, x, op); } +void vcmpngtsd(const Xmm& x, const Operand& op) { vcmpngtsd(x, x, op); } +void vcmpngtss(const Xmm& x, const Operand& op) { vcmpngtss(x, x, op); } +void vcmpnle_uqpd(const Xmm& x, const Operand& op) { vcmpnle_uqpd(x, x, op); } +void vcmpnle_uqps(const Xmm& x, const Operand& op) { vcmpnle_uqps(x, x, op); } +void vcmpnle_uqsd(const Xmm& x, const Operand& op) { vcmpnle_uqsd(x, x, op); } +void vcmpnle_uqss(const Xmm& x, const Operand& op) { vcmpnle_uqss(x, x, op); } +void vcmpnlepd(const Xmm& x, const Operand& op) { vcmpnlepd(x, x, op); } +void vcmpnleps(const Xmm& x, const Operand& op) { vcmpnleps(x, x, op); } +void vcmpnlesd(const Xmm& x, const Operand& op) { vcmpnlesd(x, x, op); } +void vcmpnless(const Xmm& x, const Operand& op) { vcmpnless(x, x, op); } +void vcmpnlt_uqpd(const Xmm& x, const Operand& op) { vcmpnlt_uqpd(x, x, op); } +void vcmpnlt_uqps(const Xmm& x, const Operand& op) { vcmpnlt_uqps(x, x, op); } +void vcmpnlt_uqsd(const Xmm& x, const Operand& op) { vcmpnlt_uqsd(x, x, op); } +void vcmpnlt_uqss(const Xmm& x, const Operand& op) { vcmpnlt_uqss(x, x, op); } +void vcmpnltpd(const Xmm& x, const Operand& op) { vcmpnltpd(x, x, op); } +void vcmpnltps(const Xmm& x, const Operand& op) { vcmpnltps(x, x, op); } +void vcmpnltsd(const Xmm& x, const Operand& op) { vcmpnltsd(x, x, op); } +void vcmpnltss(const Xmm& x, const Operand& op) { vcmpnltss(x, x, op); } +void vcmpord_spd(const Xmm& x, const Operand& op) { vcmpord_spd(x, x, op); } +void vcmpord_sps(const Xmm& x, const Operand& op) { vcmpord_sps(x, x, op); } +void vcmpord_ssd(const Xmm& x, const Operand& op) { vcmpord_ssd(x, x, op); } +void vcmpord_sss(const Xmm& x, const Operand& op) { vcmpord_sss(x, x, op); } +void vcmpordpd(const Xmm& x, const Operand& op) { vcmpordpd(x, x, op); } +void vcmpordps(const Xmm& x, const Operand& op) { vcmpordps(x, x, op); } +void vcmpordsd(const Xmm& x, const Operand& op) { vcmpordsd(x, x, op); } +void vcmpordss(const Xmm& x, const Operand& op) { vcmpordss(x, x, op); } +void vcmppd(const Xmm& x, const Operand& op, uint8 imm) { vcmppd(x, x, op, imm); } +void vcmpps(const Xmm& x, const Operand& op, uint8 imm) { vcmpps(x, x, op, imm); } +void vcmpsd(const Xmm& x, const Operand& op, uint8 imm) { vcmpsd(x, x, op, imm); } +void vcmpss(const Xmm& x, const Operand& op, uint8 imm) { vcmpss(x, x, op, imm); } +void vcmptrue_uspd(const Xmm& x, const Operand& op) { vcmptrue_uspd(x, x, op); } +void vcmptrue_usps(const Xmm& x, const Operand& op) { vcmptrue_usps(x, x, op); } +void vcmptrue_ussd(const Xmm& x, const Operand& op) { vcmptrue_ussd(x, x, op); } +void vcmptrue_usss(const Xmm& x, const Operand& op) { vcmptrue_usss(x, x, op); } +void vcmptruepd(const Xmm& x, const Operand& op) { vcmptruepd(x, x, op); } +void vcmptrueps(const Xmm& x, const Operand& op) { vcmptrueps(x, x, op); } +void vcmptruesd(const Xmm& x, const Operand& op) { vcmptruesd(x, x, op); } +void vcmptruess(const Xmm& x, const Operand& op) { vcmptruess(x, x, op); } +void vcmpunord_spd(const Xmm& x, const Operand& op) { vcmpunord_spd(x, x, op); } +void vcmpunord_sps(const Xmm& x, const Operand& op) { vcmpunord_sps(x, x, op); } +void vcmpunord_ssd(const Xmm& x, const Operand& op) { vcmpunord_ssd(x, x, op); } +void vcmpunord_sss(const Xmm& x, const Operand& op) { vcmpunord_sss(x, x, op); } +void vcmpunordpd(const Xmm& x, const Operand& op) { vcmpunordpd(x, x, op); } +void vcmpunordps(const Xmm& x, const Operand& op) { vcmpunordps(x, x, op); } +void vcmpunordsd(const Xmm& x, const Operand& op) { vcmpunordsd(x, x, op); } +void vcmpunordss(const Xmm& x, const Operand& op) { vcmpunordss(x, x, op); } +void vcvtsd2ss(const Xmm& x, const Operand& op) { vcvtsd2ss(x, x, op); } +void vcvtsi2sd(const Xmm& x, const Operand& op) { vcvtsi2sd(x, x, op); } +void vcvtsi2ss(const Xmm& x, const Operand& op) { vcvtsi2ss(x, x, op); } +void vcvtss2sd(const Xmm& x, const Operand& op) { vcvtss2sd(x, x, op); } +void vdppd(const Xmm& x, const Operand& op, uint8 imm) { vdppd(x, x, op, imm); } +void vdpps(const Xmm& x, const Operand& op, uint8 imm) { vdpps(x, x, op, imm); } +void vinsertps(const Xmm& x, const Operand& op, uint8 imm) { vinsertps(x, x, op, imm); } +void vmpsadbw(const Xmm& x, const Operand& op, uint8 imm) { vmpsadbw(x, x, op, imm); } +void vpackssdw(const Xmm& x, const Operand& op) { vpackssdw(x, x, op); } +void vpacksswb(const Xmm& x, const Operand& op) { vpacksswb(x, x, op); } +void vpackusdw(const Xmm& x, const Operand& op) { vpackusdw(x, x, op); } +void vpackuswb(const Xmm& x, const Operand& op) { vpackuswb(x, x, op); } +void vpaddb(const Xmm& x, const Operand& op) { vpaddb(x, x, op); } +void vpaddd(const Xmm& x, const Operand& op) { vpaddd(x, x, op); } +void vpaddq(const Xmm& x, const Operand& op) { vpaddq(x, x, op); } +void vpaddsb(const Xmm& x, const Operand& op) { vpaddsb(x, x, op); } +void vpaddsw(const Xmm& x, const Operand& op) { vpaddsw(x, x, op); } +void vpaddusb(const Xmm& x, const Operand& op) { vpaddusb(x, x, op); } +void vpaddusw(const Xmm& x, const Operand& op) { vpaddusw(x, x, op); } +void vpaddw(const Xmm& x, const Operand& op) { vpaddw(x, x, op); } +void vpalignr(const Xmm& x, const Operand& op, uint8 imm) { vpalignr(x, x, op, imm); } +void vpand(const Xmm& x, const Operand& op) { vpand(x, x, op); } +void vpandn(const Xmm& x, const Operand& op) { vpandn(x, x, op); } +void vpavgb(const Xmm& x, const Operand& op) { vpavgb(x, x, op); } +void vpavgw(const Xmm& x, const Operand& op) { vpavgw(x, x, op); } +void vpblendd(const Xmm& x, const Operand& op, uint8 imm) { vpblendd(x, x, op, imm); } +void vpblendvb(const Xmm& x1, const Operand& op, const Xmm& x4) { vpblendvb(x1, x1, op, x4); } +void vpblendw(const Xmm& x, const Operand& op, uint8 imm) { vpblendw(x, x, op, imm); } +void vpclmulqdq(const Xmm& x, const Operand& op, uint8 imm) { vpclmulqdq(x, x, op, imm); } +void vpcmpeqb(const Xmm& x, const Operand& op) { vpcmpeqb(x, x, op); } +void vpcmpeqd(const Xmm& x, const Operand& op) { vpcmpeqd(x, x, op); } +void vpcmpeqq(const Xmm& x, const Operand& op) { vpcmpeqq(x, x, op); } +void vpcmpeqw(const Xmm& x, const Operand& op) { vpcmpeqw(x, x, op); } +void vpcmpgtb(const Xmm& x, const Operand& op) { vpcmpgtb(x, x, op); } +void vpcmpgtd(const Xmm& x, const Operand& op) { vpcmpgtd(x, x, op); } +void vpcmpgtq(const Xmm& x, const Operand& op) { vpcmpgtq(x, x, op); } +void vpcmpgtw(const Xmm& x, const Operand& op) { vpcmpgtw(x, x, op); } +void vphaddd(const Xmm& x, const Operand& op) { vphaddd(x, x, op); } +void vphaddsw(const Xmm& x, const Operand& op) { vphaddsw(x, x, op); } +void vphaddw(const Xmm& x, const Operand& op) { vphaddw(x, x, op); } +void vphsubd(const Xmm& x, const Operand& op) { vphsubd(x, x, op); } +void vphsubsw(const Xmm& x, const Operand& op) { vphsubsw(x, x, op); } +void vphsubw(const Xmm& x, const Operand& op) { vphsubw(x, x, op); } +void vpinsrb(const Xmm& x, const Operand& op, uint8 imm) { vpinsrb(x, x, op, imm); } +void vpinsrd(const Xmm& x, const Operand& op, uint8 imm) { vpinsrd(x, x, op, imm); } +void vpinsrq(const Xmm& x, const Operand& op, uint8 imm) { vpinsrq(x, x, op, imm); } +void vpinsrw(const Xmm& x, const Operand& op, uint8 imm) { vpinsrw(x, x, op, imm); } +void vpmaddubsw(const Xmm& x, const Operand& op) { vpmaddubsw(x, x, op); } +void vpmaddwd(const Xmm& x, const Operand& op) { vpmaddwd(x, x, op); } +void vpmaxsb(const Xmm& x, const Operand& op) { vpmaxsb(x, x, op); } +void vpmaxsd(const Xmm& x, const Operand& op) { vpmaxsd(x, x, op); } +void vpmaxsw(const Xmm& x, const Operand& op) { vpmaxsw(x, x, op); } +void vpmaxub(const Xmm& x, const Operand& op) { vpmaxub(x, x, op); } +void vpmaxud(const Xmm& x, const Operand& op) { vpmaxud(x, x, op); } +void vpmaxuw(const Xmm& x, const Operand& op) { vpmaxuw(x, x, op); } +void vpminsb(const Xmm& x, const Operand& op) { vpminsb(x, x, op); } +void vpminsd(const Xmm& x, const Operand& op) { vpminsd(x, x, op); } +void vpminsw(const Xmm& x, const Operand& op) { vpminsw(x, x, op); } +void vpminub(const Xmm& x, const Operand& op) { vpminub(x, x, op); } +void vpminud(const Xmm& x, const Operand& op) { vpminud(x, x, op); } +void vpminuw(const Xmm& x, const Operand& op) { vpminuw(x, x, op); } +void vpmuldq(const Xmm& x, const Operand& op) { vpmuldq(x, x, op); } +void vpmulhrsw(const Xmm& x, const Operand& op) { vpmulhrsw(x, x, op); } +void vpmulhuw(const Xmm& x, const Operand& op) { vpmulhuw(x, x, op); } +void vpmulhw(const Xmm& x, const Operand& op) { vpmulhw(x, x, op); } +void vpmulld(const Xmm& x, const Operand& op) { vpmulld(x, x, op); } +void vpmullw(const Xmm& x, const Operand& op) { vpmullw(x, x, op); } +void vpmuludq(const Xmm& x, const Operand& op) { vpmuludq(x, x, op); } +void vpor(const Xmm& x, const Operand& op) { vpor(x, x, op); } +void vpsadbw(const Xmm& x, const Operand& op) { vpsadbw(x, x, op); } +void vpsignb(const Xmm& x, const Operand& op) { vpsignb(x, x, op); } +void vpsignd(const Xmm& x, const Operand& op) { vpsignd(x, x, op); } +void vpsignw(const Xmm& x, const Operand& op) { vpsignw(x, x, op); } +void vpslld(const Xmm& x, const Operand& op) { vpslld(x, x, op); } +void vpslld(const Xmm& x, uint8 imm) { vpslld(x, x, imm); } +void vpslldq(const Xmm& x, uint8 imm) { vpslldq(x, x, imm); } +void vpsllq(const Xmm& x, const Operand& op) { vpsllq(x, x, op); } +void vpsllq(const Xmm& x, uint8 imm) { vpsllq(x, x, imm); } +void vpsllw(const Xmm& x, const Operand& op) { vpsllw(x, x, op); } +void vpsllw(const Xmm& x, uint8 imm) { vpsllw(x, x, imm); } +void vpsrad(const Xmm& x, const Operand& op) { vpsrad(x, x, op); } +void vpsrad(const Xmm& x, uint8 imm) { vpsrad(x, x, imm); } +void vpsraw(const Xmm& x, const Operand& op) { vpsraw(x, x, op); } +void vpsraw(const Xmm& x, uint8 imm) { vpsraw(x, x, imm); } +void vpsrld(const Xmm& x, const Operand& op) { vpsrld(x, x, op); } +void vpsrld(const Xmm& x, uint8 imm) { vpsrld(x, x, imm); } +void vpsrldq(const Xmm& x, uint8 imm) { vpsrldq(x, x, imm); } +void vpsrlq(const Xmm& x, const Operand& op) { vpsrlq(x, x, op); } +void vpsrlq(const Xmm& x, uint8 imm) { vpsrlq(x, x, imm); } +void vpsrlw(const Xmm& x, const Operand& op) { vpsrlw(x, x, op); } +void vpsrlw(const Xmm& x, uint8 imm) { vpsrlw(x, x, imm); } +void vpsubb(const Xmm& x, const Operand& op) { vpsubb(x, x, op); } +void vpsubd(const Xmm& x, const Operand& op) { vpsubd(x, x, op); } +void vpsubq(const Xmm& x, const Operand& op) { vpsubq(x, x, op); } +void vpsubsb(const Xmm& x, const Operand& op) { vpsubsb(x, x, op); } +void vpsubsw(const Xmm& x, const Operand& op) { vpsubsw(x, x, op); } +void vpsubusb(const Xmm& x, const Operand& op) { vpsubusb(x, x, op); } +void vpsubusw(const Xmm& x, const Operand& op) { vpsubusw(x, x, op); } +void vpsubw(const Xmm& x, const Operand& op) { vpsubw(x, x, op); } +void vpunpckhbw(const Xmm& x, const Operand& op) { vpunpckhbw(x, x, op); } +void vpunpckhdq(const Xmm& x, const Operand& op) { vpunpckhdq(x, x, op); } +void vpunpckhqdq(const Xmm& x, const Operand& op) { vpunpckhqdq(x, x, op); } +void vpunpckhwd(const Xmm& x, const Operand& op) { vpunpckhwd(x, x, op); } +void vpunpcklbw(const Xmm& x, const Operand& op) { vpunpcklbw(x, x, op); } +void vpunpckldq(const Xmm& x, const Operand& op) { vpunpckldq(x, x, op); } +void vpunpcklqdq(const Xmm& x, const Operand& op) { vpunpcklqdq(x, x, op); } +void vpunpcklwd(const Xmm& x, const Operand& op) { vpunpcklwd(x, x, op); } +void vpxor(const Xmm& x, const Operand& op) { vpxor(x, x, op); } +void vrcpss(const Xmm& x, const Operand& op) { vrcpss(x, x, op); } +void vroundsd(const Xmm& x, const Operand& op, uint8 imm) { vroundsd(x, x, op, imm); } +void vroundss(const Xmm& x, const Operand& op, uint8 imm) { vroundss(x, x, op, imm); } +void vrsqrtss(const Xmm& x, const Operand& op) { vrsqrtss(x, x, op); } +void vshufpd(const Xmm& x, const Operand& op, uint8 imm) { vshufpd(x, x, op, imm); } +void vshufps(const Xmm& x, const Operand& op, uint8 imm) { vshufps(x, x, op, imm); } +void vsqrtsd(const Xmm& x, const Operand& op) { vsqrtsd(x, x, op); } +void vsqrtss(const Xmm& x, const Operand& op) { vsqrtss(x, x, op); } +void vunpckhpd(const Xmm& x, const Operand& op) { vunpckhpd(x, x, op); } +void vunpckhps(const Xmm& x, const Operand& op) { vunpckhps(x, x, op); } +void vunpcklpd(const Xmm& x, const Operand& op) { vunpcklpd(x, x, op); } +void vunpcklps(const Xmm& x, const Operand& op) { vunpcklps(x, x, op); } +#endif +#ifdef XBYAK64 +void jecxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void cdqe() { db(0x48); db(0x98); } +void cqo() { db(0x48); db(0x99); } +void cmpsq() { db(0x48); db(0xA7); } +void movsq() { db(0x48); db(0xA5); } +void scasq() { db(0x48); db(0xAF); } +void stosq() { db(0x48); db(0xAB); } +void cmpxchg16b(const Address& addr) { opModM(addr, Reg64(1), 0x0F, 0xC7); } +void movq(const Reg64& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } +void movq(const Mmx& mmx, const Reg64& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } +void movsxd(const Reg64& reg, const Operand& op) { if (!op.isBit(32)) throw Error(ERR_BAD_COMBINATION); opModRM(reg, op, op.isREG(), op.isMEM(), 0x63); } +void pextrq(const Operand& op, const Xmm& xmm, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x16, 0x66, 0, imm, 0x3A); } +void pinsrq(const Xmm& xmm, const Operand& op, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x22, 0x66, 0, imm, 0x3A); } +void vcvtss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_ER_X | T_N8, 0x2D); } +void vcvttss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_SAE_X | T_N8, 0x2C); } +void vcvtsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_ER_X, 0x2D); } +void vcvttsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_SAE_X, 0x2C); } +void vmovq(const Xmm& x, const Reg64& r) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x6E); } +void vmovq(const Reg64& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x7E); } +#else +void jcxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jcxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void aaa() { db(0x37); } +void aad() { db(0xD5); db(0x0A); } +void aam() { db(0xD4); db(0x0A); } +void aas() { db(0x3F); } +void daa() { db(0x27); } +void das() { db(0x2F); } +void popad() { db(0x61); } +void popfd() { db(0x9D); } +void pusha() { db(0x60); } +void pushad() { db(0x60); } +void pushfd() { db(0x9C); } +void popa() { db(0x61); } +#endif +#ifndef XBYAK_NO_OP_NAMES +void and(const Operand& op1, const Operand& op2) { and_(op1, op2); } +void and(const Operand& op, uint32 imm) { and_(op, imm); } +void or(const Operand& op1, const Operand& op2) { or_(op1, op2); } +void or(const Operand& op, uint32 imm) { or_(op, imm); } +void xor(const Operand& op1, const Operand& op2) { xor_(op1, op2); } +void xor(const Operand& op, uint32 imm) { xor_(op, imm); } +void not(const Operand& op) { not_(op); } +#endif +#ifndef XBYAK_DISABLE_AVX512 +void kaddb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4A); } +void kaddd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x4A); } +void kaddq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4A); } +void kaddw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4A); } +void kandb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x41); } +void kandd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x41); } +void kandnb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x42); } +void kandnd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x42); } +void kandnq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x42); } +void kandnw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x42); } +void kandq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x41); } +void kandw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x41); } +void kmovb(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W0, 0x91); } +void kmovb(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W0, 0x90); } +void kmovb(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_66 | T_W0, 0x92); } +void kmovb(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_66 | T_W0, 0x93); } +void kmovd(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W1, 0x91); } +void kmovd(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W1, 0x90); } +void kmovd(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W0, 0x92); } +void kmovd(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W0, 0x93); } +void kmovq(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W1, 0x91); } +void kmovq(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W1, 0x90); } +void kmovw(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W0, 0x91); } +void kmovw(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W0, 0x90); } +void kmovw(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_W0, 0x92); } +void kmovw(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_W0, 0x93); } +void knotb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x44); } +void knotd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x44); } +void knotq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x44); } +void knotw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x44); } +void korb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x45); } +void kord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x45); } +void korq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x45); } +void kortestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x98); } +void kortestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x98); } +void kortestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x98); } +void kortestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x98); } +void korw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x45); } +void kshiftlb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x32, imm); } +void kshiftld(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x33, imm); } +void kshiftlq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x33, imm); } +void kshiftlw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x32, imm); } +void kshiftrb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x30, imm); } +void kshiftrd(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x31, imm); } +void kshiftrq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x31, imm); } +void kshiftrw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x30, imm); } +void ktestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x99); } +void ktestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x99); } +void ktestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x99); } +void ktestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x99); } +void kunpckbw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4B); } +void kunpckdq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4B); } +void kunpckwd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4B); } +void kxnorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x46); } +void kxnord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x46); } +void kxnorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x46); } +void kxnorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x46); } +void kxorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x47); } +void kxord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x47); } +void kxorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x47); } +void kxorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x47); } +void v4fmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x9A); } +void v4fmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0x9B); } +void v4fnmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0xAA); } +void v4fnmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0xAB); } +void valignd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x03, imm); } +void valignq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x03, imm); } +void vblendmpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x65); } +void vblendmps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x65); } +void vbroadcastf32x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x19); } +void vbroadcastf32x4(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x1A); } +void vbroadcastf32x8(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x1B); } +void vbroadcastf64x2(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x1A); } +void vbroadcastf64x4(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x1B); } +void vbroadcasti32x2(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x59); } +void vbroadcasti32x4(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x5A); } +void vbroadcasti32x8(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x5B); } +void vbroadcasti64x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x5A); } +void vbroadcasti64x4(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x5B); } +void vcmppd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpps(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpsd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N8 | T_F2 | T_0F | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpss(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N4 | T_F3 | T_0F | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcompressb(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x63); } +void vcompresspd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8A); } +void vcompressps(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8A); } +void vcompressw(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x63); } +void vcvtpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7B); } +void vcvtpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x79); } +void vcvtpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x79); } +void vcvtps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x7B); } +void vcvtps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x79); } +void vcvtps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x79); } +void vcvtqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0xE6); } +void vcvtqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5B); } +void vcvtsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } +void vcvtss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } +void vcvttpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x7A); } +void vcvttpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_SAE_Z, 0x78); } +void vcvttpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x78); } +void vcvttps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x7A); } +void vcvttps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x78); } +void vcvttps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x78); } +void vcvttsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } +void vcvttss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } +void vcvtudq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_F3 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0x7A); } +void vcvtudq2ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x7A); } +void vcvtuqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7A); } +void vcvtuqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_F2 | T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x7A); } +void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F3 | T_0F | T_MUST_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vdbpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x42, imm); } +void vexp2pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xC8); } +void vexp2ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xC8); } +void vexpandpd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x88); } +void vexpandps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x88); } +void vextractf32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x19, imm); } +void vextractf32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1B, imm); } +void vextractf64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x19, imm); } +void vextractf64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1B, imm); } +void vextracti32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x39, imm); } +void vextracti32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3B, imm); } +void vextracti64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x39, imm); } +void vextracti64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3B, imm); } +void vfixupimmpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x54, imm); } +void vfixupimmps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x54, imm); } +void vfixupimmsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } +void vfixupimmss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } +void vfpclasspd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm); } +void vfpclassps(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm); } +void vfpclasssd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW1 | T_N8, 0x67, imm); } +void vfpclassss(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW0 | T_N4, 0x67, imm); } +void vgatherdpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 1); } +void vgatherdps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 0); } +void vgatherpf0dpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf0dps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf0qpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf0qps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1dpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf1dps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf1qpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1qps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherqpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 0); } +void vgatherqps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 2); } +void vgetexppd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x42); } +void vgetexpps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x42); } +void vgetexpsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x43); } +void vgetexpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x43); } +void vgetmantpd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x26, imm); } +void vgetmantps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x26, imm); } +void vgetmantsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } +void vgetmantss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } +void vinsertf32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x18, imm); } +void vinsertf32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1A, imm); } +void vinsertf64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x18, imm); } +void vinsertf64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1A, imm); } +void vinserti32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x38, imm); } +void vinserti32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3A, imm); } +void vinserti64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x38, imm); } +void vinserti64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3A, imm); } +void vmovdqa32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqa32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqa64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqa64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu16(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu16(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu8(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu8(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vp4dpwssd(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x52); } +void vp4dpwssds(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x53); } +void vpabsq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_MUST_EVEX | T_EW1 | T_B64 | T_YMM, 0x1F); } +void vpandd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDB); } +void vpandnd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDF); } +void vpandnq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDF); } +void vpandq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDB); } +void vpblendmb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x66); } +void vpblendmd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x64); } +void vpblendmq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x64); } +void vpblendmw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x66); } +void vpbroadcastb(const Xmm& x, const Reg8& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7A); } +void vpbroadcastd(const Xmm& x, const Reg32& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7C); } +void vpbroadcastmb2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1, 0x2A); } +void vpbroadcastmw2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0, 0x3A); } +void vpbroadcastw(const Xmm& x, const Reg16& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7B); } +void vpcmpb(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3F, imm); } +void vpcmpd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1F, imm); } +void vpcmpeqb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x74); } +void vpcmpeqd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_B32, 0x76); } +void vpcmpeqq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x29); } +void vpcmpeqw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x75); } +void vpcmpgtb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x64); } +void vpcmpgtd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x66); } +void vpcmpgtq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x37); } +void vpcmpgtw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x65); } +void vpcmpq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1F, imm); } +void vpcmpub(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3E, imm); } +void vpcmpud(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1E, imm); } +void vpcmpuq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1E, imm); } +void vpcmpuw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3E, imm); } +void vpcmpw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3F, imm); } +void vpcompressd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8B); } +void vpcompressq(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8B); } +void vpconflictd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xC4); } +void vpconflictq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xC4); } +void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50); } +void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x51); } +void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x52); } +void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x53); } +void vpermb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8D); } +void vpermi2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x75); } +void vpermi2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x76); } +void vpermi2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x77); } +void vpermi2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x77); } +void vpermi2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x76); } +void vpermi2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x75); } +void vpermt2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7D); } +void vpermt2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7E); } +void vpermt2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7F); } +void vpermt2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7F); } +void vpermt2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7E); } +void vpermt2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7D); } +void vpermw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8D); } +void vpexpandb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } +void vpexpandd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x89); } +void vpexpandq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x89); } +void vpexpandw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } +void vpgatherdd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 0); } +void vpgatherdq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 1); } +void vpgatherqd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 2); } +void vpgatherqq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 0); } +void vplzcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x44); } +void vplzcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x44); } +void vpmadd52huq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB5); } +void vpmadd52luq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB4); } +void vpmaxsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3D); } +void vpmaxuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3F); } +void vpminsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x39); } +void vpminuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3B); } +void vpmovb2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x29); } +void vpmovd2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x39); } +void vpmovdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x31, false); } +void vpmovdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x33, true); } +void vpmovm2b(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x28); } +void vpmovm2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x38); } +void vpmovm2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x38); } +void vpmovm2w(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x28); } +void vpmovq2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x39); } +void vpmovqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x32, false); } +void vpmovqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x35, true); } +void vpmovqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x34, false); } +void vpmovsdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x21, false); } +void vpmovsdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x23, true); } +void vpmovsqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x22, false); } +void vpmovsqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x25, true); } +void vpmovsqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x24, false); } +void vpmovswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x20, true); } +void vpmovusdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x11, false); } +void vpmovusdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x13, true); } +void vpmovusqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x12, false); } +void vpmovusqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x15, true); } +void vpmovusqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x14, false); } +void vpmovuswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x10, true); } +void vpmovw2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x29); } +void vpmovwb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x30, true); } +void vpmullq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x40); } +void vpmultishiftqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x83); } +void vpopcntb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } +void vpopcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x55); } +void vpopcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x55); } +void vpopcntw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } +void vpord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEB); } +void vporq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEB); } +void vprold(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } +void vprolq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vprolvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x15); } +void vprolvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x15); } +void vprord(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } +void vprorq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vprorvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x14); } +void vprorvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x14); } +void vpscatterdd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 0); } +void vpscatterdq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 1); } +void vpscatterqd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 2); } +void vpscatterqq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 0); } +void vpshldd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71, imm); } +void vpshldq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71, imm); } +void vpshldvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71); } +void vpshldvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71); } +void vpshldvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70); } +void vpshldw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70, imm); } +void vpshrdd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73, imm); } +void vpshrdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73, imm); } +void vpshrdvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73); } +void vpshrdvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73); } +void vpshrdvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72); } +void vpshrdw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72, imm); } +void vpshufbitqmb(const Opmask& k, const Xmm& x, const Operand& op) { opVex(k, &x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8F); } +void vpsllvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x12); } +void vpsraq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vpsraq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX, 0xE2); } +void vpsravq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x46); } +void vpsravw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x11); } +void vpsrlvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x10); } +void vpternlogd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x25, imm); } +void vpternlogq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x25, imm); } +void vptestmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } +void vptestmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } +void vptestmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestnmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestnmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } +void vptestnmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } +void vptestnmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } +void vpxord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEF); } +void vpxorq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEF); } +void vrangepd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x50, imm); } +void vrangeps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50, imm); } +void vrangesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } +void vrangess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } +void vrcp14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4C); } +void vrcp14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4C); } +void vrcp14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX, 0x4D); } +void vrcp14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX, 0x4D); } +void vrcp28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCA); } +void vrcp28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCA); } +void vrcp28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCB); } +void vrcp28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCB); } +void vreducepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x56, imm); } +void vreduceps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x56, imm); } +void vreducesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } +void vreducess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } +void vrndscalepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x09, imm); } +void vrndscaleps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x08, imm); } +void vrndscalesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_MUST_EVEX, 0x0B, imm); } +void vrndscaless(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_MUST_EVEX, 0x0A, imm); } +void vrsqrt14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4E); } +void vrsqrt14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4E); } +void vrsqrt14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x4F); } +void vrsqrt14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x4F); } +void vrsqrt28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCC); } +void vrsqrt28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCC); } +void vrsqrt28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCD); } +void vrsqrt28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCD); } +void vscalefpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x2C); } +void vscalefps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x2C); } +void vscalefsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_ER_X | T_MUST_EVEX, 0x2D); } +void vscalefss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_ER_X | T_MUST_EVEX, 0x2D); } +void vscatterdpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 1); } +void vscatterdps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 0); } +void vscatterpf0dpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf0dps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf0qpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf0qps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1dpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf1dps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf1qpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1qps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterqpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 0); } +void vscatterqps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 2); } +void vshuff32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x23, imm); } +void vshuff64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x23, imm); } +void vshufi32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x43, imm); } +void vshufi64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x43, imm); } +#ifdef XBYAK64 +void kmovq(const Opmask& k, const Reg64& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W1, 0x92); } +void kmovq(const Reg64& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W1, 0x93); } +void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7C); } +#endif +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h new file mode 100644 index 0000000000..8ef076e680 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h @@ -0,0 +1,772 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#ifndef XBYAK_XBYAK_UTIL_H_ +#define XBYAK_XBYAK_UTIL_H_ + +/** + utility class and functions for Xbyak + Xbyak::util::Clock ; rdtsc timer + Xbyak::util::Cpu ; detect CPU + @note this header is UNDER CONSTRUCTION! +*/ +#include "xbyak.h" + +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + #define XBYAK_INTEL_CPU_SPECIFIC +#endif + +#ifdef XBYAK_INTEL_CPU_SPECIFIC +#ifdef _MSC_VER + #if (_MSC_VER < 1400) && defined(XBYAK32) + static inline __declspec(naked) void __cpuid(int[4], int) + { + __asm { + push ebx + push esi + mov eax, dword ptr [esp + 4 * 2 + 8] // eaxIn + cpuid + mov esi, dword ptr [esp + 4 * 2 + 4] // data + mov dword ptr [esi], eax + mov dword ptr [esi + 4], ebx + mov dword ptr [esi + 8], ecx + mov dword ptr [esi + 12], edx + pop esi + pop ebx + ret + } + } + #else + #include // for __cpuid + #endif +#else + #ifndef __GNUC_PREREQ + #define __GNUC_PREREQ(major, minor) ((((__GNUC__) << 16) + (__GNUC_MINOR__)) >= (((major) << 16) + (minor))) + #endif + #if __GNUC_PREREQ(4, 3) && !defined(__APPLE__) + #include + #else + #if defined(__APPLE__) && defined(XBYAK32) // avoid err : can't find a register in class `BREG' while reloading `asm' + #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) + #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) + #else + #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) + #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) + #endif + #endif +#endif +#endif + +namespace Xbyak { namespace util { + +typedef enum { + SmtLevel = 1, + CoreLevel = 2 +} IntelCpuTopologyLevel; + +/** + CPU detection class +*/ +class Cpu { + uint64 type_; + //system topology + bool x2APIC_supported_; + static const size_t maxTopologyLevels = 2; + unsigned int numCores_[maxTopologyLevels]; + + static const unsigned int maxNumberCacheLevels = 10; + unsigned int dataCacheSize_[maxNumberCacheLevels]; + unsigned int coresSharignDataCache_[maxNumberCacheLevels]; + unsigned int dataCacheLevels_; + + unsigned int get32bitAsBE(const char *x) const + { + return x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); + } + unsigned int mask(int n) const + { + return (1U << n) - 1; + } + void setFamily() + { + unsigned int data[4] = {}; + getCpuid(1, data); + stepping = data[0] & mask(4); + model = (data[0] >> 4) & mask(4); + family = (data[0] >> 8) & mask(4); + // type = (data[0] >> 12) & mask(2); + extModel = (data[0] >> 16) & mask(4); + extFamily = (data[0] >> 20) & mask(8); + if (family == 0x0f) { + displayFamily = family + extFamily; + } else { + displayFamily = family; + } + if (family == 6 || family == 0x0f) { + displayModel = (extModel << 4) + model; + } else { + displayModel = model; + } + } + unsigned int extractBit(unsigned int val, unsigned int base, unsigned int end) + { + return (val >> base) & ((1u << (end - base)) - 1); + } + void setNumCores() + { + if ((type_ & tINTEL) == 0) return; + + unsigned int data[4] = {}; + + /* CAUTION: These numbers are configuration as shipped by Intel. */ + getCpuidEx(0x0, 0, data); + if (data[0] >= 0xB) { + /* + if leaf 11 exists(x2APIC is supported), + we use it to get the number of smt cores and cores on socket + + leaf 0xB can be zeroed-out by a hypervisor + */ + x2APIC_supported_ = true; + for (unsigned int i = 0; i < maxTopologyLevels; i++) { + getCpuidEx(0xB, i, data); + IntelCpuTopologyLevel level = (IntelCpuTopologyLevel)extractBit(data[2], 8, 15); + if (level == SmtLevel || level == CoreLevel) { + numCores_[level - 1] = extractBit(data[1], 0, 15); + } + } + } else { + /* + Failed to deremine num of cores without x2APIC support. + TODO: USE initial APIC ID to determine ncores. + */ + numCores_[SmtLevel - 1] = 0; + numCores_[CoreLevel - 1] = 0; + } + + } + void setCacheHierarchy() + { + if ((type_ & tINTEL) == 0) return; + const unsigned int NO_CACHE = 0; + const unsigned int DATA_CACHE = 1; +// const unsigned int INSTRUCTION_CACHE = 2; + const unsigned int UNIFIED_CACHE = 3; + unsigned int smt_width = 0; + unsigned int logical_cores = 0; + unsigned int data[4] = {}; + + if (x2APIC_supported_) { + smt_width = numCores_[0]; + logical_cores = numCores_[1]; + } + + /* + Assumptions: + the first level of data cache is not shared (which is the + case for every existing architecture) and use this to + determine the SMT width for arch not supporting leaf 11. + when leaf 4 reports a number of core less than numCores_ + on socket reported by leaf 11, then it is a correct number + of cores not an upperbound. + */ + for (int i = 0; dataCacheLevels_ < maxNumberCacheLevels; i++) { + getCpuidEx(0x4, i, data); + unsigned int cacheType = extractBit(data[0], 0, 4); + if (cacheType == NO_CACHE) break; + if (cacheType == DATA_CACHE || cacheType == UNIFIED_CACHE) { + unsigned int actual_logical_cores = extractBit(data[0], 14, 25) + 1; + if (logical_cores != 0) { // true only if leaf 0xB is supported and valid + actual_logical_cores = (std::min)(actual_logical_cores, logical_cores); + } + assert(actual_logical_cores != 0); + dataCacheSize_[dataCacheLevels_] = + (extractBit(data[1], 22, 31) + 1) + * (extractBit(data[1], 12, 21) + 1) + * (extractBit(data[1], 0, 11) + 1) + * (data[2] + 1); + if (cacheType == DATA_CACHE && smt_width == 0) smt_width = actual_logical_cores; + assert(smt_width != 0); + // FIXME: check and fix number of cores sharing L3 cache for different configurations + // (HT-, 2 sockets), (HT-, 1 socket), (HT+, 2 sockets), (HT+, 1 socket) + coresSharignDataCache_[dataCacheLevels_] = (std::max)(actual_logical_cores / smt_width, 1u); + dataCacheLevels_++; + } + } + } + +public: + int model; + int family; + int stepping; + int extModel; + int extFamily; + int displayFamily; // family + extFamily + int displayModel; // model + extModel + + unsigned int getNumCores(IntelCpuTopologyLevel level) { + if (level != SmtLevel && level != CoreLevel) throw Error(ERR_BAD_PARAMETER); + if (!x2APIC_supported_) throw Error(ERR_X2APIC_IS_NOT_SUPPORTED); + return (level == CoreLevel) + ? numCores_[level - 1] / numCores_[SmtLevel - 1] + : numCores_[level - 1]; + } + + unsigned int getDataCacheLevels() const { return dataCacheLevels_; } + unsigned int getCoresSharingDataCache(unsigned int i) const + { + if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); + return coresSharignDataCache_[i]; + } + unsigned int getDataCacheSize(unsigned int i) const + { + if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); + return dataCacheSize_[i]; + } + + /* + data[] = { eax, ebx, ecx, edx } + */ + static inline void getCpuid(unsigned int eaxIn, unsigned int data[4]) + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + __cpuid(reinterpret_cast(data), eaxIn); + #else + __cpuid(eaxIn, data[0], data[1], data[2], data[3]); + #endif +#else + (void)eaxIn; + (void)data; +#endif + } + static inline void getCpuidEx(unsigned int eaxIn, unsigned int ecxIn, unsigned int data[4]) + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + __cpuidex(reinterpret_cast(data), eaxIn, ecxIn); + #else + __cpuid_count(eaxIn, ecxIn, data[0], data[1], data[2], data[3]); + #endif +#else + (void)eaxIn; + (void)ecxIn; + (void)data; +#endif + } + static inline uint64 getXfeature() + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + return _xgetbv(0); + #else + unsigned int eax, edx; + // xgetvb is not support on gcc 4.2 +// __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0)); + __asm__ volatile(".byte 0x0f, 0x01, 0xd0" : "=a"(eax), "=d"(edx) : "c"(0)); + return ((uint64)edx << 32) | eax; + #endif +#else + return 0; +#endif + } + typedef uint64 Type; + + static const Type NONE = 0; + static const Type tMMX = 1 << 0; + static const Type tMMX2 = 1 << 1; + static const Type tCMOV = 1 << 2; + static const Type tSSE = 1 << 3; + static const Type tSSE2 = 1 << 4; + static const Type tSSE3 = 1 << 5; + static const Type tSSSE3 = 1 << 6; + static const Type tSSE41 = 1 << 7; + static const Type tSSE42 = 1 << 8; + static const Type tPOPCNT = 1 << 9; + static const Type tAESNI = 1 << 10; + static const Type tSSE5 = 1 << 11; + static const Type tOSXSAVE = 1 << 12; + static const Type tPCLMULQDQ = 1 << 13; + static const Type tAVX = 1 << 14; + static const Type tFMA = 1 << 15; + + static const Type t3DN = 1 << 16; + static const Type tE3DN = 1 << 17; + static const Type tSSE4a = 1 << 18; + static const Type tRDTSCP = 1 << 19; + static const Type tAVX2 = 1 << 20; + static const Type tBMI1 = 1 << 21; // andn, bextr, blsi, blsmsk, blsr, tzcnt + static const Type tBMI2 = 1 << 22; // bzhi, mulx, pdep, pext, rorx, sarx, shlx, shrx + static const Type tLZCNT = 1 << 23; + + static const Type tINTEL = 1 << 24; + static const Type tAMD = 1 << 25; + + static const Type tENHANCED_REP = 1 << 26; // enhanced rep movsb/stosb + static const Type tRDRAND = 1 << 27; + static const Type tADX = 1 << 28; // adcx, adox + static const Type tRDSEED = 1 << 29; // rdseed + static const Type tSMAP = 1 << 30; // stac + static const Type tHLE = uint64(1) << 31; // xacquire, xrelease, xtest + static const Type tRTM = uint64(1) << 32; // xbegin, xend, xabort + static const Type tF16C = uint64(1) << 33; // vcvtph2ps, vcvtps2ph + static const Type tMOVBE = uint64(1) << 34; // mobve + static const Type tAVX512F = uint64(1) << 35; + static const Type tAVX512DQ = uint64(1) << 36; + static const Type tAVX512_IFMA = uint64(1) << 37; + static const Type tAVX512IFMA = tAVX512_IFMA; + static const Type tAVX512PF = uint64(1) << 38; + static const Type tAVX512ER = uint64(1) << 39; + static const Type tAVX512CD = uint64(1) << 40; + static const Type tAVX512BW = uint64(1) << 41; + static const Type tAVX512VL = uint64(1) << 42; + static const Type tAVX512_VBMI = uint64(1) << 43; + static const Type tAVX512VBMI = tAVX512_VBMI; // changed by Intel's manual + static const Type tAVX512_4VNNIW = uint64(1) << 44; + static const Type tAVX512_4FMAPS = uint64(1) << 45; + static const Type tPREFETCHWT1 = uint64(1) << 46; + static const Type tPREFETCHW = uint64(1) << 47; + static const Type tSHA = uint64(1) << 48; + static const Type tMPX = uint64(1) << 49; + static const Type tAVX512_VBMI2 = uint64(1) << 50; + static const Type tGFNI = uint64(1) << 51; + static const Type tVAES = uint64(1) << 52; + static const Type tVPCLMULQDQ = uint64(1) << 53; + static const Type tAVX512_VNNI = uint64(1) << 54; + static const Type tAVX512_BITALG = uint64(1) << 55; + static const Type tAVX512_VPOPCNTDQ = uint64(1) << 56; + + Cpu() + : type_(NONE) + , x2APIC_supported_(false) + , numCores_() + , dataCacheSize_() + , coresSharignDataCache_() + , dataCacheLevels_(0) + { + unsigned int data[4] = {}; + const unsigned int& EAX = data[0]; + const unsigned int& EBX = data[1]; + const unsigned int& ECX = data[2]; + const unsigned int& EDX = data[3]; + getCpuid(0, data); + const unsigned int maxNum = EAX; + static const char intel[] = "ntel"; + static const char amd[] = "cAMD"; + if (ECX == get32bitAsBE(amd)) { + type_ |= tAMD; + getCpuid(0x80000001, data); + if (EDX & (1U << 31)) type_ |= t3DN; + if (EDX & (1U << 15)) type_ |= tCMOV; + if (EDX & (1U << 30)) type_ |= tE3DN; + if (EDX & (1U << 22)) type_ |= tMMX2; + if (EDX & (1U << 27)) type_ |= tRDTSCP; + } + if (ECX == get32bitAsBE(intel)) { + type_ |= tINTEL; + getCpuid(0x80000001, data); + if (EDX & (1U << 27)) type_ |= tRDTSCP; + if (ECX & (1U << 5)) type_ |= tLZCNT; + if (ECX & (1U << 8)) type_ |= tPREFETCHW; + } + getCpuid(1, data); + if (ECX & (1U << 0)) type_ |= tSSE3; + if (ECX & (1U << 9)) type_ |= tSSSE3; + if (ECX & (1U << 19)) type_ |= tSSE41; + if (ECX & (1U << 20)) type_ |= tSSE42; + if (ECX & (1U << 22)) type_ |= tMOVBE; + if (ECX & (1U << 23)) type_ |= tPOPCNT; + if (ECX & (1U << 25)) type_ |= tAESNI; + if (ECX & (1U << 1)) type_ |= tPCLMULQDQ; + if (ECX & (1U << 27)) type_ |= tOSXSAVE; + if (ECX & (1U << 30)) type_ |= tRDRAND; + if (ECX & (1U << 29)) type_ |= tF16C; + + if (EDX & (1U << 15)) type_ |= tCMOV; + if (EDX & (1U << 23)) type_ |= tMMX; + if (EDX & (1U << 25)) type_ |= tMMX2 | tSSE; + if (EDX & (1U << 26)) type_ |= tSSE2; + + if (type_ & tOSXSAVE) { + // check XFEATURE_ENABLED_MASK[2:1] = '11b' + uint64 bv = getXfeature(); + if ((bv & 6) == 6) { + if (ECX & (1U << 28)) type_ |= tAVX; + if (ECX & (1U << 12)) type_ |= tFMA; + if (((bv >> 5) & 7) == 7) { + getCpuidEx(7, 0, data); + if (EBX & (1U << 16)) type_ |= tAVX512F; + if (type_ & tAVX512F) { + if (EBX & (1U << 17)) type_ |= tAVX512DQ; + if (EBX & (1U << 21)) type_ |= tAVX512_IFMA; + if (EBX & (1U << 26)) type_ |= tAVX512PF; + if (EBX & (1U << 27)) type_ |= tAVX512ER; + if (EBX & (1U << 28)) type_ |= tAVX512CD; + if (EBX & (1U << 30)) type_ |= tAVX512BW; + if (EBX & (1U << 31)) type_ |= tAVX512VL; + if (ECX & (1U << 1)) type_ |= tAVX512_VBMI; + if (ECX & (1U << 6)) type_ |= tAVX512_VBMI2; + if (ECX & (1U << 8)) type_ |= tGFNI; + if (ECX & (1U << 9)) type_ |= tVAES; + if (ECX & (1U << 10)) type_ |= tVPCLMULQDQ; + if (ECX & (1U << 11)) type_ |= tAVX512_VNNI; + if (ECX & (1U << 12)) type_ |= tAVX512_BITALG; + if (ECX & (1U << 14)) type_ |= tAVX512_VPOPCNTDQ; + if (EDX & (1U << 2)) type_ |= tAVX512_4VNNIW; + if (EDX & (1U << 3)) type_ |= tAVX512_4FMAPS; + } + } + } + } + if (maxNum >= 7) { + getCpuidEx(7, 0, data); + if (type_ & tAVX && (EBX & (1U << 5))) type_ |= tAVX2; + if (EBX & (1U << 3)) type_ |= tBMI1; + if (EBX & (1U << 8)) type_ |= tBMI2; + if (EBX & (1U << 9)) type_ |= tENHANCED_REP; + if (EBX & (1U << 18)) type_ |= tRDSEED; + if (EBX & (1U << 19)) type_ |= tADX; + if (EBX & (1U << 20)) type_ |= tSMAP; + if (EBX & (1U << 4)) type_ |= tHLE; + if (EBX & (1U << 11)) type_ |= tRTM; + if (EBX & (1U << 14)) type_ |= tMPX; + if (EBX & (1U << 29)) type_ |= tSHA; + if (ECX & (1U << 0)) type_ |= tPREFETCHWT1; + } + setFamily(); + setNumCores(); + setCacheHierarchy(); + } + void putFamily() const + { + printf("family=%d, model=%X, stepping=%d, extFamily=%d, extModel=%X\n", + family, model, stepping, extFamily, extModel); + printf("display:family=%X, model=%X\n", displayFamily, displayModel); + } + bool has(Type type) const + { + return (type & type_) != 0; + } +}; + +class Clock { +public: + static inline uint64 getRdtsc() + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + return __rdtsc(); + #else + unsigned int eax, edx; + __asm__ volatile("rdtsc" : "=a"(eax), "=d"(edx)); + return ((uint64)edx << 32) | eax; + #endif +#else + // TODO: Need another impl of Clock or rdtsc-equivalent for non-x86 cpu + return 0; +#endif + } + Clock() + : clock_(0) + , count_(0) + { + } + void begin() + { + clock_ -= getRdtsc(); + } + void end() + { + clock_ += getRdtsc(); + count_++; + } + int getCount() const { return count_; } + uint64 getClock() const { return clock_; } + void clear() { count_ = 0; clock_ = 0; } +private: + uint64 clock_; + int count_; +}; + +#ifdef XBYAK64 +const int UseRCX = 1 << 6; +const int UseRDX = 1 << 7; + +class Pack { + static const size_t maxTblNum = 15; + const Xbyak::Reg64 *tbl_[maxTblNum]; + size_t n_; +public: + Pack() : tbl_(), n_(0) {} + Pack(const Xbyak::Reg64 *tbl, size_t n) { init(tbl, n); } + Pack(const Pack& rhs) + : n_(rhs.n_) + { + for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; + } + Pack& operator=(const Pack& rhs) + { + n_ = rhs.n_; + for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; + return *this; + } + Pack(const Xbyak::Reg64& t0) + { n_ = 1; tbl_[0] = &t0; } + Pack(const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 2; tbl_[0] = &t0; tbl_[1] = &t1; } + Pack(const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 3; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; } + Pack(const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 4; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; } + Pack(const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 5; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; } + Pack(const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 6; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; } + Pack(const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 7; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; } + Pack(const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 8; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; } + Pack(const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 9; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; } + Pack(const Xbyak::Reg64& t9, const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 10; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; tbl_[9] = &t9; } + Pack& append(const Xbyak::Reg64& t) + { + if (n_ == maxTblNum) { + fprintf(stderr, "ERR Pack::can't append\n"); + throw Error(ERR_BAD_PARAMETER); + } + tbl_[n_++] = &t; + return *this; + } + void init(const Xbyak::Reg64 *tbl, size_t n) + { + if (n > maxTblNum) { + fprintf(stderr, "ERR Pack::init bad n=%d\n", (int)n); + throw Error(ERR_BAD_PARAMETER); + } + n_ = n; + for (size_t i = 0; i < n; i++) { + tbl_[i] = &tbl[i]; + } + } + const Xbyak::Reg64& operator[](size_t n) const + { + if (n >= n_) { + fprintf(stderr, "ERR Pack bad n=%d(%d)\n", (int)n, (int)n_); + throw Error(ERR_BAD_PARAMETER); + } + return *tbl_[n]; + } + size_t size() const { return n_; } + /* + get tbl[pos, pos + num) + */ + Pack sub(size_t pos, size_t num = size_t(-1)) const + { + if (num == size_t(-1)) num = n_ - pos; + if (pos + num > n_) { + fprintf(stderr, "ERR Pack::sub bad pos=%d, num=%d\n", (int)pos, (int)num); + throw Error(ERR_BAD_PARAMETER); + } + Pack pack; + pack.n_ = num; + for (size_t i = 0; i < num; i++) { + pack.tbl_[i] = tbl_[pos + i]; + } + return pack; + } + void put() const + { + for (size_t i = 0; i < n_; i++) { + printf("%s ", tbl_[i]->toString()); + } + printf("\n"); + } +}; + +class StackFrame { +#ifdef XBYAK64_WIN + static const int noSaveNum = 6; + static const int rcxPos = 0; + static const int rdxPos = 1; +#else + static const int noSaveNum = 8; + static const int rcxPos = 3; + static const int rdxPos = 2; +#endif + static const int maxRegNum = 14; // maxRegNum = 16 - rsp - rax + Xbyak::CodeGenerator *code_; + int pNum_; + int tNum_; + bool useRcx_; + bool useRdx_; + int saveNum_; + int P_; + bool makeEpilog_; + Xbyak::Reg64 pTbl_[4]; + Xbyak::Reg64 tTbl_[maxRegNum]; + Pack p_; + Pack t_; + StackFrame(const StackFrame&); + void operator=(const StackFrame&); +public: + const Pack& p; + const Pack& t; + /* + make stack frame + @param sf [in] this + @param pNum [in] num of function parameter(0 <= pNum <= 4) + @param tNum [in] num of temporary register(0 <= tNum, with UseRCX, UseRDX) #{pNum + tNum [+rcx] + [rdx]} <= 14 + @param stackSizeByte [in] local stack size + @param makeEpilog [in] automatically call close() if true + + you can use + rax + gp0, ..., gp(pNum - 1) + gt0, ..., gt(tNum-1) + rcx if tNum & UseRCX + rdx if tNum & UseRDX + rsp[0..stackSizeByte - 1] + */ + StackFrame(Xbyak::CodeGenerator *code, int pNum, int tNum = 0, int stackSizeByte = 0, bool makeEpilog = true) + : code_(code) + , pNum_(pNum) + , tNum_(tNum & ~(UseRCX | UseRDX)) + , useRcx_((tNum & UseRCX) != 0) + , useRdx_((tNum & UseRDX) != 0) + , saveNum_(0) + , P_(0) + , makeEpilog_(makeEpilog) + , p(p_) + , t(t_) + { + using namespace Xbyak; + if (pNum < 0 || pNum > 4) throw Error(ERR_BAD_PNUM); + const int allRegNum = pNum + tNum_ + (useRcx_ ? 1 : 0) + (useRdx_ ? 1 : 0); + if (tNum_ < 0 || allRegNum > maxRegNum) throw Error(ERR_BAD_TNUM); + const Reg64& _rsp = code->rsp; + saveNum_ = (std::max)(0, allRegNum - noSaveNum); + const int *tbl = getOrderTbl() + noSaveNum; + for (int i = 0; i < saveNum_; i++) { + code->push(Reg64(tbl[i])); + } + P_ = (stackSizeByte + 7) / 8; + if (P_ > 0 && (P_ & 1) == (saveNum_ & 1)) P_++; // (rsp % 16) == 8, then increment P_ for 16 byte alignment + P_ *= 8; + if (P_ > 0) code->sub(_rsp, P_); + int pos = 0; + for (int i = 0; i < pNum; i++) { + pTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); + } + for (int i = 0; i < tNum_; i++) { + tTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); + } + if (useRcx_ && rcxPos < pNum) code_->mov(code_->r10, code_->rcx); + if (useRdx_ && rdxPos < pNum) code_->mov(code_->r11, code_->rdx); + p_.init(pTbl_, pNum); + t_.init(tTbl_, tNum_); + } + /* + make epilog manually + @param callRet [in] call ret() if true + */ + void close(bool callRet = true) + { + using namespace Xbyak; + const Reg64& _rsp = code_->rsp; + const int *tbl = getOrderTbl() + noSaveNum; + if (P_ > 0) code_->add(_rsp, P_); + for (int i = 0; i < saveNum_; i++) { + code_->pop(Reg64(tbl[saveNum_ - 1 - i])); + } + + if (callRet) code_->ret(); + } + ~StackFrame() + { + if (!makeEpilog_) return; + try { + close(); + } catch (std::exception& e) { + printf("ERR:StackFrame %s\n", e.what()); + //exit(1); + } + } +private: + const int *getOrderTbl() const + { + using namespace Xbyak; + static const int tbl[] = { +#ifdef XBYAK64_WIN + Operand::RCX, Operand::RDX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, Operand::RDI, Operand::RSI, +#else + Operand::RDI, Operand::RSI, Operand::RDX, Operand::RCX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, +#endif + Operand::RBX, Operand::RBP, Operand::R12, Operand::R13, Operand::R14, Operand::R15 + }; + return &tbl[0]; + } + int getRegIdx(int& pos) const + { + assert(pos < maxRegNum); + using namespace Xbyak; + const int *tbl = getOrderTbl(); + int r = tbl[pos++]; + if (useRcx_) { + if (r == Operand::RCX) { return Operand::R10; } + if (r == Operand::R10) { r = tbl[pos++]; } + } + if (useRdx_) { + if (r == Operand::RDX) { return Operand::R11; } + if (r == Operand::R11) { return tbl[pos++]; } + } + return r; + } +}; +#endif + +} } // end of util +#endif diff --git a/thirdparty/oidn/weights/rtlightmap_hdr.tza b/thirdparty/oidn/weights/rtlightmap_hdr.tza new file mode 100644 index 0000000000..12459a33bc Binary files /dev/null and b/thirdparty/oidn/weights/rtlightmap_hdr.tza differ -- cgit v1.2.3