diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp
new file mode 100644
index 0000000000..2c5a8e8973
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp
@@ -0,0 +1,305 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_uni_lrn_kernel_f32.hpp"
+#include "jit_uni_lrn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
+template <cpu_isa_t isa>
+jit_uni_lrn_fwd_t<isa>::jit_uni_lrn_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd), ker_(nullptr)
+ , ker_first_(nullptr), ker_last_(nullptr)
+{
+ using namespace alg_kind;
+
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const int ls = pd()->desc()->local_size;
+ float A = pd()->desc()->lrn_alpha / ls;
+ float K = pd()->desc()->lrn_k;
+
+ auto pk = pd()->desc()->prop_kind;
+ auto ak = pd()->desc()->alg_kind;
+ auto dat_tag = pd()->dat_tag_;
+
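+    /* Across-channel nChw8c needs border-aware kernels: the third argument
+       of nchw8c_across() selects the variant (-1 for the first 8-channel
+       block, 0 for interior blocks, +1 for the last block). */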
+ if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) {
+ ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw8c_across(H, W, 0), A, K, pk);
+ ker_first_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw8c_across(H, W, -1), A, K, pk);
+ ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw8c_across(H, W, +1), A, K, pk);
+ } else if (dat_tag == nChw8c && ak == lrn_within_channel) {
+        /* within channel: a local_size (x) local_size window, i.e. ls*ls
+           elements, so alpha picks up a second factor of 1/ls here */
+        A /= ls;
+ ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw8c_within(H, W, ls), A, K, pk);
+ } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) {
+ ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw_across(C, H*W, 0), A, K, pk);
+        int remainder = (H*W) % VECTOR_LENGTH;
+        if (remainder != 0) {
+            ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+                        nchw_across(C, H*W, remainder), A, K, pk);
+        }
+    } else { /* the only case left after pd_t::init(): nhwc, across channels */
+ ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(nhwc_across(C), A, K, pk);
+ }
+}
+
+template <cpu_isa_t isa>
+jit_uni_lrn_fwd_t<isa>::~jit_uni_lrn_fwd_t()
+{ delete ker_; delete ker_first_; delete ker_last_; }
+
+template <cpu_isa_t isa>
+void jit_uni_lrn_fwd_t<isa>::execute_forward(const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE);
+
+ const int N = pd()->MB();
+ const int C = pd()->C();
+ const int HW = pd()->H() * pd()->W();
+ const int ls = pd()->desc()->local_size;
+
+ auto ak = pd()->desc()->alg_kind;
+ auto dat_tag = pd()->dat_tag_;
+
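+    /* The dispatch below mirrors the kernel selection in the constructor;
+       work is split into (mb, channel-block) or (mb, pixel-block) jobs. */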
+ if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) {
+ parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
+ jit_args_fwd_t args;
+ args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ if (c8 == 0)
+ (*ker_first_)(&args);
+ else if (c8 == C / VECTOR_LENGTH - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ });
+ }
+ else if (dat_tag == nChw8c && ak == lrn_within_channel) {
+ parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
+ jit_args_fwd_t args;
+ args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ (*ker_)(&args);
+ });
+ }
+ else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) {
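+        /* nchw: pixels are processed in blocks of VECTOR_LENGTH; a partial
+           tail block runs the remainder kernel built in the constructor */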
+ parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH,
+ [&](int n, int hw8) {
+ jit_args_fwd_t args;
+ args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH];
+ args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH];
+ args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH];
+ if ((hw8 + 1)*VECTOR_LENGTH > HW)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ });
+ }
+ else { // nhwc
+ parallel_nd(N, HW, [&](int n, int hw) {
+ jit_args_fwd_t args;
+ args.src = &src[n*HW*C + hw * C];
+ args.dst = &dst[n*HW*C + hw * C];
+ args.scratch = &ws[n*HW*C + hw * C];
+ (*ker_)(&args);
+ });
+ }
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_lrn_fwd_t<isa>::pd_t::init() {
+ using namespace prop_kind;
+ using namespace alg_kind;
+
+ const memory_desc_wrapper data_d(src_md());
+ bool ok = true
+ && mayiuse(isa)
+ && is_fwd()
+ && everyone_is(data_type::f32, data_d.data_type())
+ && !has_zero_dim_memory()
+ && data_d.ndims() == 4
+ && data_d.dims()[1] % VECTOR_LENGTH == 0
+ && data_d.dims()[1] >= 2 * VECTOR_LENGTH
+ && desc()->lrn_beta == 0.75
+ && attr()->has_default_values();
+ if (!ok) return unimplemented;
+
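+    /* forward training keeps intermediate normalization data in a
+       workspace with the same layout as src, for reuse in backward */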
+ if (desc_.prop_kind == forward_training) ws_md_ = *src_md();
+
+ dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c, nchw, nhwc);
+
+ bool args_ok_across = true
+ && desc()->alg_kind == lrn_across_channels
+ && desc()->local_size == 5
+ && one_of(dat_tag_, nChw8c, nchw, nhwc);
+
+    const int jit_max_local_size = 5; // larger sizes make the generated code too big
+ bool args_ok_within = true
+ && desc()->alg_kind == lrn_within_channel
+        && desc()->local_size <= (jit_max_local_size <= MAX_LOCAL_SIZE
+                                 ? jit_max_local_size : MAX_LOCAL_SIZE)
+ && data_d.dims()[2] >= desc()->local_size
+ && data_d.dims()[3] >= desc()->local_size
+ && one_of(dat_tag_, nChw8c);
+
+ return args_ok_across || args_ok_within ? success : unimplemented;
+}
+
+template <cpu_isa_t isa>
+jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr)
+{
+ using namespace alg_kind;
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const int ls = pd()->desc()->local_size;
+ float A = pd()->desc()->lrn_alpha / ls;
+ float B = pd()->desc()->lrn_beta;
+
+    int use_h_parallelism = 0; // H-level parallelism is wired in but currently disabled
+    if (C / VECTOR_LENGTH == 1) {
+        /* a single channel block has to handle both borders itself */
+        ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
+            nchw8c_across(H, W, 3), A, B, use_h_parallelism);
+    }
+    else {
+        ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
+            nchw8c_across(H, W, 0), A, B, use_h_parallelism);
+        ker_first_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
+            nchw8c_across(H, W, -1), A, B, use_h_parallelism);
+        ker_last_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
+            nchw8c_across(H, W, +1), A, B, use_h_parallelism);
+    }
+}
+
+template <cpu_isa_t isa>
+jit_uni_lrn_bwd_t<isa>::~jit_uni_lrn_bwd_t()
+{
+ delete ker_; delete ker_first_; delete ker_last_;
+}
+
+template <cpu_isa_t isa>
+void jit_uni_lrn_bwd_t<isa>::execute_backward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const int N = pd()->MB();
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+
+    int use_h_parallelism = 0; // keep in sync with the constructor; currently disabled
+    if (use_h_parallelism) {
+ parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) {
+ auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH
+ + h*W*VECTOR_LENGTH;
+ jit_args_bwd_t args;
+ args.src = &src[offset];
+ args.diff_dst = &diff_dst[offset];
+ args.scratch = &ws[offset];
+ args.diff_src = &diff_src[offset];
+ if (C / VECTOR_LENGTH == 1)
+ (*ker_)(&args);
+ else if (c8 == 0)
+ (*ker_first_)(&args);
+ else if (c8 == C / VECTOR_LENGTH - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ });
+ }
+ else {
+ parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
+ auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH;
+ jit_args_bwd_t args;
+ args.src = &src[offset];
+ args.diff_dst = &diff_dst[offset];
+ args.scratch = &ws[offset];
+ args.diff_src = &diff_src[offset];
+ if (C / VECTOR_LENGTH == 1)
+ (*ker_)(&args);
+ else if (c8 == 0)
+ (*ker_first_)(&args);
+ else if (c8 == C / VECTOR_LENGTH - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ });
+ }
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_lrn_bwd_t<isa>::pd_t::init() {
+ using namespace prop_kind;
+ using namespace alg_kind;
+
+ const memory_desc_wrapper data_d(src_md());
+ bool ok = true
+ && mayiuse(isa)
+ && !is_fwd()
+ && utils::everyone_is(data_type::f32, data_d.data_type())
+ && !has_zero_dim_memory()
+ && data_d.ndims() == 4
+ && data_d.dims()[1] % VECTOR_LENGTH == 0
+ && desc()->lrn_beta == 0.75
+ && attr()->has_default_values();
+ if (!ok) return unimplemented;
+
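+    /* backward consumes the workspace produced by forward training; the
+       forward hint descriptor must agree on its layout */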
+ ws_md_ = *src_md();
+ if (!compare_ws(hint_fwd_pd_)) return unimplemented;
+
+ dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c);
+
+ bool args_ok_across = true
+ && desc()->alg_kind == lrn_across_channels
+ && desc()->local_size == 5
+ && utils::one_of(dat_tag_, nChw8c);
+
+ return args_ok_across ? success : unimplemented;
+}
+
+template struct jit_uni_lrn_fwd_t<sse42>;
+template struct jit_uni_lrn_fwd_t<avx2>;
+template struct jit_uni_lrn_bwd_t<avx2>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s