author     Juan Linietsky <reduzio@gmail.com>  2020-05-01 09:34:23 -0300
committer  Juan Linietsky <reduzio@gmail.com>  2020-05-10 15:59:09 -0300
commit     1bea8e1eacc68bcedbd3f207395bccf11011dae2 (patch)
tree       b75303a69491978c1e13360a3e6f355c5234dfe0 /thirdparty/oidn/mkl-dnn/src/cpu/rnn
parent     6a0473bcc23c096ef9ee929632a209761c2668f6 (diff)
New lightmapper
-Added LocalVector (needed it)
-Added stb_rect_pack (It's pretty cool, we could probably use it for other stuff too)
-Fixes and changes all around the place
-Added library for 128-bit fixed point (required for Delaunay3D)
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/rnn')
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp            90
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp              180
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp          170
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp             143
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp              113
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp            191
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp  401
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp               788
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp               328
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp          380
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp             426
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp             225
12 files changed, 3435 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
new file mode 100644
index 0000000000..537084db91
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
@@ -0,0 +1,90 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Common for RNN and LSTM cell execution
+ */
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+using namespace rnn_utils;
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_cell_execution_sig(
+ (_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execution)) {
+ if (!rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
+ rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
+ states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
+ rnn.gates_ws_ld);
+ }
+ (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
+ 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
+ rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
+
+ if (rnn_postgemm_ != nullptr)
+ rnn_postgemm_->execute<src_data_t, acc_data_t>(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+ else
+ (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+}
+template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
+template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
+
+template <>
+rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) {
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+
+ /// bwd by data on the cell
+ (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
+ 1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld,
+ 0.0, diff_states_t_l_, rnn.states_ws_ld);
+
+ if (!rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
+ rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
+ rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
+ &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
+
+ /// bwd by weights on the cell
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
+ diff_w_layer_, rnn.diff_weights_layer_ld);
+ }
+
+ if (!rnn.merge_gemm_iter)
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0,
+ diff_w_iter_, rnn.diff_weights_iter_ld);
+
+ /// bwd by bias: we just accumulate diffs from the gates
+ gates_reduction(rnn, ws_gates_, diff_bias_);
+}
+
+}
+}
+}
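
For reference, the math that cell_execution above hands to the two GEMM calls can be written as a plain loop over one mini-batch element. This is only an illustrative sketch, not code from the commit; the row-major weight layout and the names n_gates, dic, slc, sic mirror the rnn_conf_t fields used above but are assumptions of the sketch.

#include <vector>

// gates = W_layer * x + W_iter * h_prev, computed with naive loops instead of
// the gemm_layer_func / gemm_iter_func calls above; bias and activations are
// then applied by the elemwise/postgemm kernel.
static void naive_cell_gates(int n_gates, int dic, int slc, int sic,
        const std::vector<float> &w_layer,  // (n_gates * dic) rows x slc cols
        const std::vector<float> &w_iter,   // (n_gates * dic) rows x sic cols
        const std::vector<float> &x,        // slc
        const std::vector<float> &h_prev,   // sic
        std::vector<float> &gates) {        // n_gates * dic
    for (int g = 0; g < n_gates * dic; ++g) {
        float acc = 0.f;
        for (int c = 0; c < slc; ++c) acc += w_layer[g * slc + c] * x[c];      // layer GEMM
        for (int c = 0; c < sic; ++c) acc += w_iter[g * sic + c] * h_prev[c];  // iter GEMM
        gates[g] = acc;
    }
}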
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp
new file mode 100644
index 0000000000..e1a61d4c62
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp
@@ -0,0 +1,180 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution GRU
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace rnn_utils;
+
+#define AOC array_offset_calculator
+template <>
+rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_[0]);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+
+ // 1. gemm Wx[0-2],x
+ if (!rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
+ rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
+ states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
+ rnn.gates_ws_ld);
+ }
+
+ // 2. gemm Wh[0-1],h
+ (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb,
+ rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
+ rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
+
+ // 3. activation zt and rt + elemwise multiplication rt,ht-1
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
+ ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
+ states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j);
+ }
+ });
+
+ // 4. gemm Wh[2],h~t
+ (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1],
+ rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0,
+ &(ws_gates(0, 2, 0)), rnn.gates_ws_ld);
+
+ // 5. activation h~t + calculate ht
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
+ states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
+ + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
+ }
+ });
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) {
+ assert(!"GRU int8 is not supported");
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+ ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
+ ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
+
+ // use state memory for intermediate computations
+ // TODO: use cell ws for that
+ float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0));
+ float *hG1_ = dhG1_;
+ AOC<float, 2> dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld);
+ AOC<float, 2> hG1(hG1_, rnn.states_nld, rnn.states_ws_ld);
+
+ // 1. calculate dG2, dG0, and part of dht-1
+ // dG2^ = dh * (1 - G0) * (1 - G2^2)
+ // dG0^ = dh * (ht-1 - G2) * G0 * (1 - G0)
+ // dht-1 (part) = dh * G0
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float h = states_tm1_l(i, j);
+ float dHt = diff_states_tp1_l(0, i, j)
+ + diff_states_t_lp1(rnn.n_states, i, j);
+ float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt
+ * one_m_square(ws_gates(i, 2, j));
+ float dG0 = (h - ws_gates(i, 2, j)) * dHt
+ * x_m_square(ws_gates(i, 0, j));
+
+ diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
+ ws_gates(i, 0, j) = dG0;
+ ws_gates(i, 2, j) = dG2;
+ }
+ });
+
+ // 2. calculate intermediate d(hG1)
+ // d(hG1) = dG2 * W2h^t
+ (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1],
+ rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0,
+ dhG1_, rnn.states_ws_ld);
+
+ // 3. calculate dG1^ and part of dht-1
+ // dG1^ = d(hG1) * h * G1 * (1 - G1)
+ // dht-1 (part) += d(hG1) * G1
+ // h * G1 (required for dWh)
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float h = states_tm1_l(i, j);
+ float G1 = ws_gates(i, 1, j);
+ diff_states_t_l(0, i, j) += dhG1(i, j) * G1;
+ ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1);
+ hG1(i, j) = G1 * h;
+ }
+ });
+
+ // 4. calculate diff weights
+ // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h)
+ gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
+ rnn.diff_weights_iter_ld);
+ gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)),
+ rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0,
+ &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld);
+
+ // 5. calculate diff states
+ // dht-1 += dG1 * W1h + dG0 * W0h
+ (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
+ (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0],
+ rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0,
+ diff_states_t_l_, rnn.states_ws_ld);
+
+ if (!rnn.merge_gemm_layer) {
+ // dWx += [dG0 dG1 dG2] * [x]
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
+ diff_w_layer_, rnn.diff_weights_layer_ld);
+ // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x
+ (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
+ rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
+ rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
+ &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld);
+ }
+
+ // 6. calculate diff bias
+ gates_reduction(rnn, ws_gates_, diff_bias_);
+}
+#undef AOC
+
+}
+}
+}
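
A scalar reading of the forward GRU cell above (dic = slc = sic = mb = 1), following the numbered steps in cell_execution_gru. It is an illustrative sketch under that simplification, not code from the commit; wx, wh and b stand for the per-gate layer weights, iteration weights and biases.

#include <cmath>

static float gru_cell_scalar(const float wx[3], const float wh[3],
        const float b[3], float x, float h_prev) {
    const auto logistic = [](float v) { return 1.f / (1.f + std::exp(-v)); };
    float G0 = logistic(wx[0] * x + wh[0] * h_prev + b[0]);         // step 3: update-like gate
    float G1 = logistic(wx[1] * x + wh[1] * h_prev + b[1]);         // step 3: reset gate
    float G2 = std::tanh(wx[2] * x + wh[2] * (G1 * h_prev) + b[2]); // steps 4-5: candidate h~t
    return h_prev * G0 + (1.f - G0) * G2;                           // step 5: ht
}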
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp
new file mode 100644
index 0000000000..8dea8c90a4
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp
@@ -0,0 +1,170 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution GRU with linear before reset
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace rnn_utils;
+#define AOC array_offset_calculator
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+ ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_);
+ AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j);
+ ws_gates(i, 0, j) = logistic_fwd(
+ ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j));
+ ws_gates(i, 1, j) = logistic_fwd(
+ ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j));
+ ws_gates(i, 2, j) = tanh_fwd(
+ ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j));
+ states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
+ + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
+ if (rnn.is_training)
+ ws_Wh_b(i, j) = Wh_b;
+ }
+ });
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) {
+ assert(!"GRU LBR int8 is not supported");
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) {
+ if (!rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
+ rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
+ states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
+ rnn.gates_ws_ld);
+ }
+ (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
+ 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
+ rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld);
+ (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) {
+ assert(!"GRU LBR int8 is not supported");
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
+ ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
+ ws_gates_aoc_t ws_gates_r(rnn, ws_cell_);
+ AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic);
+
+ // 1. calculate dG0, dG1 and dG2
+ // dG0 = (ht-1 - G2) * dht * (1 - G0) * G0
+ // dG1 = (Wh*ht-1 + b) * dG2 * (1 - G1) * G1
+ // dG2 = (1 - G0) * dht * (1 - G2^2)
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float h = states_tm1_l(i, j);
+ float dHt = diff_states_tp1_l(0, i, j)
+ + diff_states_t_lp1(rnn.n_states, i, j);
+ float dG0 = (h - ws_gates(i, 2, j)) * dHt
+ * x_m_square(ws_gates(i, 0, j));
+ float dG2 = (1.0f - ws_gates(i, 0, j))
+ * one_m_square(ws_gates(i, 2, j)) * dHt;
+ float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j));
+
+ diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
+ ws_gates(i, 2, j) = dG2;
+ ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j);
+ ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0;
+ ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1;
+ }
+ });
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) {
+ ws_gates_aoc_t ws_gates_r(rnn, ws_cell_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+
+ (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+
+ if (!rnn.merge_gemm_layer) {
+ // dx = dG * Wx^t
+ (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
+ rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
+ rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
+ &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
+ // dWx += dG^t * x
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
+ diff_w_layer_, rnn.diff_weights_layer_ld);
+ }
+ // dh += dGr * Wh^t
+ (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
+ 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld,
+ 1.0, diff_states_t_l_, rnn.states_ws_ld);
+
+ // dWh += dGr^t * h
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_,
+ rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
+ rnn.diff_weights_iter_ld);
+
+ // db1-3 += e * dG
+ // db4 += e * (r * dG2)
+ gates_reduction(rnn, ws_gates_, diff_bias_);
+
+ parallel_nd(rnn.dic, [&](int j) {
+ for (int i = 0; i < rnn.mb; i++) {
+ diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j);
+ }
+ });
+}
+
+#undef AOC
+
+}
+}
+}
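
The same kind of scalar sketch for the linear-before-reset variant above (dic = slc = sic = mb = 1): the reset gate multiplies the already-computed recurrent term Wh*h + b instead of h itself, which is what gru_lbr_elemwise stores in ws_Wh_b. Illustrative only, not code from the commit; wx, wh and the four biases b mirror the gate/bias indexing used above.

#include <cmath>

static float gru_lbr_cell_scalar(const float wx[3], const float wh[3],
        const float b[4], float x, float h_prev) {
    const auto logistic = [](float v) { return 1.f / (1.f + std::exp(-v)); };
    float Wh_b = wh[2] * h_prev + b[3];                      // "linear before reset" term
    float G0 = logistic(wx[0] * x + wh[0] * h_prev + b[0]);  // update-like gate
    float G1 = logistic(wx[1] * x + wh[1] * h_prev + b[1]);  // reset gate
    float G2 = std::tanh(wx[2] * x + G1 * Wh_b + b[2]);      // candidate h~t
    return h_prev * G0 + (1.f - G0) * G2;                    // ht
}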
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp
new file mode 100644
index 0000000000..a15ba00d4c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp
@@ -0,0 +1,143 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution LSTM
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "../simple_q10n.hpp"
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace rnn_utils;
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
+ ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
+ ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
+ ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
+ ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j));
+
+ float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j)
+ + ws_gates(i, 0, j) * ws_gates(i, 2, j);
+ states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp);
+ c_states_t_l(i, j) = tmp;
+ }
+ });
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) {
+ ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_u8_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
+ ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
+
+ float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
+ float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
+ float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
+
+ auto q_d = [&](float f) {
+ float qf = f * data_scale + data_shift;
+ return qz_a1b0<float, src_data_t>()(qf);
+ };
+
+ auto deq_w = [&](acc_data_t s, int gate, int j) {
+ return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ?
+ saturate<float>(s) * (1.f / (weights_scales[0] * data_scale)) :
+ saturate<float>(s) * (1.f / (weights_scales[gate * rnn.dic + j]
+ * data_scale));
+ };
+
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float G0 = logistic_fwd<float>(
+ deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j));
+ float G1 = logistic_fwd<float>(
+ deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j));
+ float G2 = tanh_fwd<float>(
+ deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j));
+ float G3 = logistic_fwd<float>(
+ deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j));
+ float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2;
+ states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp));
+ c_states_t_l(i, j) = tmp;
+ }
+ });
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
+ ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
+ ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float Ct = c_states_t_l(i, j);
+ /// @todo save it in the workspace in the fwd pass, or recompute it to
+ /// save bandwidth
+ float tanhCt = tanh_fwd(Ct);
+ // we have 2 incoming diffs on Ht
+ float dHt = diff_states_tp1_l(0, i, j)
+ + diff_states_t_lp1(rnn.n_states, i, j);
+ float dCt = diff_states_tp1_l(1, i, j)
+ + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt;
+
+ float dG1 = c_states_tm1_l(i, j) * dCt
+ * x_m_square(ws_gates(i, 1, j));
+ float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j));
+ float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j));
+ float dG2
+ = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j));
+
+ diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j);
+
+ ws_gates(i, 0, j) = dG0;
+ ws_gates(i, 1, j) = dG1;
+ ws_gates(i, 2, j) = dG2;
+ ws_gates(i, 3, j) = dG3;
+ }
+ });
+}
+
+}
+}
+}
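
For the f32 forward path, lstm_elemwise above reduces to the following per-element update once the layer and iteration GEMMs have produced the four pre-activation gates. A scalar sketch (dic = mb = 1), not code from the commit; g and b hold the pre-activation gates and biases.

#include <cmath>

static void lstm_elem_scalar(const float g[4], const float b[4],
        float c_prev, float &h_t, float &c_t) {
    const auto logistic = [](float v) { return 1.f / (1.f + std::exp(-v)); };
    float G0 = logistic(g[0] + b[0]);   // input-like gate
    float G1 = logistic(g[1] + b[1]);   // forget-like gate
    float G2 = std::tanh(g[2] + b[2]);  // candidate cell state
    float G3 = logistic(g[3] + b[3]);   // output gate
    c_t = G1 * c_prev + G0 * G2;
    h_t = G3 * std::tanh(c_t);
}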
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp
new file mode 100644
index 0000000000..4536e8dfad
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp
@@ -0,0 +1,113 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution of Vanilla RNN
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace rnn_utils;
+
+template <>
+float activation<alg_kind::eltwise_relu, prop_kind::forward>(
+ float dd, float s, float alpha, float cliping) {
+ return relu_fwd<float>(s, alpha);
+}
+
+template <>
+float activation<alg_kind::eltwise_relu, prop_kind::backward>(
+ float dd, float s, float alpha, float cliping) {
+ return relu_bwd<float>(dd, s, alpha);
+}
+
+template <>
+float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
+ float dd, float s, float alpha, float cliping) {
+ return tanh_fwd<float>(s);
+}
+
+template <>
+float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
+ float dd, float s, float alpha, float cliping) {
+ return dd * one_m_square<float>(s);
+}
+
+template <>
+float activation<alg_kind::eltwise_logistic, prop_kind::forward>(
+ float dd, float s, float alpha, float cliping) {
+ return logistic_fwd<float>(s);
+}
+
+template <>
+float activation<alg_kind::eltwise_logistic, prop_kind::backward>(
+ float dd, float s, float alpha, float cliping) {
+ return dd * x_m_square<float>(s);
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ for (int j = 0; j < rnn.dic; j++) {
+ const float h
+ = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0);
+ ws_gates(i, 0, j) = states_t_l(i, j) = h;
+ }
+ });
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) {
+ assert(!"VANILLA RNN int8 is not supported");
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
+ ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ for (int j = 0; j < rnn.dic; ++j) {
+ const float dH = diff_states_t_lp1(rnn.n_states, i, j)
+ + diff_states_tp1_l(0, i, j);
+ auto g = ws_gates(i, 0, j);
+ ws_gates(i, 0, j) = activation_func(dH, g, 0, 0);
+ }
+ });
+}
+
+}
+}
+}
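
The activation specializations above are paired so that the backward formula is evaluated on the saved forward output: rnn_elemwise stores h = f(s) in ws_gates during the forward pass, and the backward pass rebuilds the derivative from that stored h (one_m_square / x_m_square), so the pre-activation value never has to be kept. A minimal scalar sketch of that pairing for tanh, illustrative only:

#include <cmath>

static float vanilla_rnn_tanh_fwd(float s) { return std::tanh(s); }
// dd is the incoming gradient, h the stored forward output: d/ds tanh(s) = 1 - h^2
static float vanilla_rnn_tanh_bwd(float dd, float h) { return dd * (1.f - h * h); }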
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp
new file mode 100644
index 0000000000..b39427caf9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp
@@ -0,0 +1,191 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_RNN_PD_HPP
+#define CPU_RNN_PD_HPP
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "rnn_pd.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "rnn_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t {
+ using rnn_fwd_pd_t::rnn_fwd_pd_t;
+
+protected:
+ status_t set_default_params() {
+ using namespace format_tag;
+ if (src_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
+ if (dst_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc));
+
+ // Optional parameters
+ if (with_src_iter() && src_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc));
+ if (with_bias() && bias_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(bias_md_, ldgo));
+ if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc));
+
+ return status::success;
+ }
+
+ status_t check_layout_consistency() {
+ using namespace format_tag;
+ using namespace data_type;
+ using namespace types;
+
+ auto is_blocked = [&](memory_desc_t md, int ndims) {
+ return md.format_kind == format_kind::blocked && md.ndims == ndims;
+ };
+
+ bool ok = true;
+ ok = ok && is_blocked(src_layer_md_, 3)
+ && is_blocked(dst_layer_md_, 3);
+ ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_),
+ is_blocked(src_iter_md_, 5))
+ && IMPLICATION(!is_zero_md(&dst_iter_md_),
+ is_blocked(dst_iter_md_, 5));
+
+ if (weights_layer_md_.format_kind == format_kind::rnn_packed)
+ ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format
+ == mkldnn_ldigo_p);
+ else
+ ok = ok && rnn_utils::is_ldigo(&weights_layer_md_);
+
+ if (weights_iter_md_.format_kind == format_kind::rnn_packed)
+ ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format
+ == mkldnn_ldigo_p);
+ else
+ ok = ok && rnn_utils::is_ldigo(&weights_iter_md_);
+
+ ok = ok && IMPLICATION(!is_zero_md(&bias_md_),
+ memory_desc_matches_tag(bias_md_, ldgo));
+
+ /* Int8 is supported only for packed weights */
+ data_type_t weights_iter_dt = weights_iter_md_.data_type;
+ data_type_t weights_layer_dt = weights_layer_md_.data_type;
+ ok = ok && IMPLICATION(
+ weights_iter_dt == s8, weights_iter_md_.format_kind
+ == format_kind::rnn_packed);
+ ok = ok && IMPLICATION(
+ weights_layer_dt == s8, weights_layer_md_.format_kind
+ == format_kind::rnn_packed);
+
+ return ok ? status::success : status::unimplemented;
+ }
+};
+
+struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
+ using rnn_bwd_pd_t::rnn_bwd_pd_t;
+
+protected:
+ status_t set_default_params() {
+ using namespace format_tag;
+ if (src_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
+ if (dst_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc));
+
+ if (diff_src_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc));
+ if (diff_weights_layer_md_.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo));
+ CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo));
+ }
+ if (diff_weights_iter_md_.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo));
+ CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo));
+ }
+ if (diff_dst_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc));
+
+ // Optional parameters
+ if (with_src_iter() && src_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc));
+ if (with_bias() && bias_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(bias_md_, ldgo));
+ if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc));
+
+ if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc));
+ if (with_bias() && diff_bias_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo));
+ if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc));
+
+ return status::success;
+ }
+
+ status_t check_layout_consistency() {
+ using namespace format_tag;
+ using namespace types;
+
+ auto is_blocked = [&](memory_desc_t md, int ndims) {
+ return md.format_kind == format_kind::blocked && md.ndims == ndims;
+ };
+
+ bool ok = true;
+ ok = ok && is_blocked(src_layer_md_, 3)
+ && is_blocked(dst_layer_md_, 3);
+ ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_),
+ is_blocked(src_iter_md_, 5))
+ && IMPLICATION(!is_zero_md(&dst_iter_md_),
+ is_blocked(dst_iter_md_, 5));
+
+ if (weights_layer_md_.format_kind == format_kind::rnn_packed)
+ ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format
+ == mkldnn_ldgoi_p);
+ else
+ ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_);
+
+ if (weights_iter_md_.format_kind == format_kind::rnn_packed)
+ ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format
+ == mkldnn_ldgoi_p);
+ else
+ ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_);
+
+ ok = ok && IMPLICATION(!is_zero_md(&bias_md_),
+ memory_desc_matches_tag(bias_md_, ldgo));
+
+ ok = ok && is_blocked(diff_src_layer_md_, 3)
+ && is_blocked(diff_dst_layer_md_, 3);
+ ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_),
+ is_blocked(diff_src_iter_md_, 5))
+ && IMPLICATION(!is_zero_md(&diff_dst_iter_md_),
+ is_blocked(diff_dst_iter_md_, 5));
+
+ ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_)
+ && rnn_utils::is_ldigo(&diff_weights_iter_md_);
+ ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_),
+ memory_desc_matches_tag(diff_bias_md_, ldgo));
+
+ return ok ? status::success : status::unimplemented;
+ }
+};
+}
+}
+}
+
+#endif
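
check_layout_consistency above chains IMPLICATION(cause, effect) terms, which read as "if cause holds, effect must hold". A minimal stand-in with the same semantics (the real macro lives in mkl-dnn's utils.hpp; the helper name below is just for illustration):

// implication(p, q) == !p || q: the check only constrains the layout when the
// triggering condition is actually present.
constexpr bool implication(bool cause, bool effect) { return !cause || effect; }

// e.g. "int8 iteration weights are only allowed in rnn_packed format":
// ok = ok && implication(weights_iter_dt == s8,
//         weights_iter_md_.format_kind == format_kind::rnn_packed);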
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp
new file mode 100644
index 0000000000..09445648aa
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp
@@ -0,0 +1,401 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution LSTM
+ */
+
+#include "rnn_utils.hpp"
+#include "../jit_generator.hpp"
+#include "../jit_uni_eltwise.hpp"
+#include "c_types_map.hpp"
+#include "utils.hpp"
+
+#include "mkldnn_thread.hpp"
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_uni_rnn_postgemm_kernel : public jit_generator {
+
+ typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_,
+ void *c_states_t_l_, void *c_states_tm1_l_);
+
+ jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){}
+
+ virtual void init() = 0;
+
+template <typename src_data_t, typename acc_data_t>
+ rnn_elemwise_sig(execute) {
+ rnn_utils::ws_gates_aoc<acc_data_t> ws_gates(rnn, ws_gates_);
+ rnn_utils::bias_aoc_t bias(rnn, bias_);
+ rnn_utils::ws_states_aoc<src_data_t> states_t_l(rnn, states_t_l_);
+ rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
+ rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
+
+ // Todo: add parallelization on dic for the batch 1 case
+ // Assumption: the kernel runs a loop on dic elements
+ parallel_nd(rnn.mb, [&](int i) {
+ auto b_ = &bias(0, 0);
+ auto g_ = &ws_gates(i, 0, 0);
+ auto s_tl_ = &states_t_l(i, 0);
+ auto c_tl_ = &c_states_t_l(i, 0);
+ auto c_tm1l_ = &c_states_tm1_l(i, 0);
+ kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_);
+ });
+ }
+
+protected:
+ kernel_t kernel_;
+ const rnn_utils::rnn_conf_t &rnn_;
+ const primitive_attr_t *attr_;
+};
+
+template <cpu_isa_t isa, impl::data_type_t src_data_t>
+struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel
+{
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd)
+
+ typedef typename utils::conditional<src_data_t == data_type::u8, int32_t,
+ float>::type acc_data_t;
+ typedef typename utils::conditional<isa == avx512_core,
+ jit_uni_eltwise_injector_f32<avx512_common>,
+ jit_uni_eltwise_injector_f32<isa>>::type injector_t;
+
+ jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr)
+ : jit_uni_rnn_postgemm_kernel(rnn, attr){}
+
+ void init() override {
+ // we use rax for both constant tables as they use the same table
+ sigmoid_injector_ = new injector_t(this,
+ alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax);
+ tanh_injector_ = new injector_t(this,
+ alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax);
+ generate();
+ kernel_ = (kernel_t) this->getCode();
+ }
+
+protected:
+ injector_t *sigmoid_injector_;
+ injector_t *tanh_injector_;
+
+ // register size in bytes
+ using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
+ size_t vlen = cpu_isa_traits<isa>::vlen;
+ size_t vlen_dst = (src_data_t == data_type::u8) ? vlen/4 : vlen;
+ size_t cstate_dt_size = sizeof(float);
+ size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float);
+ size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float);
+ size_t qscale_dt_size = sizeof(float);
+ size_t bias_dt_size = sizeof(float);
+
+ void generate() {
+ using namespace Xbyak;
+
+ int mask = attr_->rnn_weights_qparams_.mask_;
+ float *weights_scales = attr_->rnn_weights_qparams_.scales_;
+ float data_scale = attr_->rnn_data_qparams_.scale_;
+ float data_shift = attr_->rnn_data_qparams_.shift_;
+
+ // Labels declaration
+ Label vector_loop_start_label, vector_loop_end_label;
+ Label rem_loop_start_label, rem_loop_end_label;
+ Label table_label;
+
+ // Register map
+ Reg64 loop_cnt(r11); // loop counter
+ Reg64 table_reg(rbx); // table is used for data scale and shifts
+ Reg64 weights_scales_reg(r13);
+ // We skip vmm0 as it can be used by the injector for masks on sse4.2
+ Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7);
+
+ // constant table map
+ Address dscale_off_addr = ptr[table_reg];
+ Address dshift_off_addr = ptr[table_reg + vlen];
+ Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen];
+ Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits<avx>::vlen];
+
+ // quantize from float to u8
+ auto q_d = [&](Vmm f, Vmm tmp_vmm) {
+ uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm);
+ uni_vmulps(f, f, dscale_off_addr); // apply scale
+ uni_vaddps(f, f, dshift_off_addr); // apply shift
+ uni_vcvtps2dq(f, f); // convert to int32
+ uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16
+ uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation
+ // Note that the results are interleaved by 128 bit chunks, so we need to merge them together
+ switch (vlen) {
+ case 64: { //avx512
+ Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx());
+ uni_vmovups(tmpz, zmm_perm_mask_addr);
+ vpermd(fz, tmpz, fz);
+ break; }
+ case 32: { //avx
+ Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx());
+ uni_vmovups(tmpy, ymm_perm_mask_addr);
+ vpermd(fy, tmpy, fy);
+ break; }
+ case 16: // sse: nothing to do
+ break;
+ default: assert(!"Unsupported case");
+ };
+ };
+
+ auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) {
+ if (packed)
+ uni_vrcpps(tmp, s);
+ else
+ uni_vrcpss(tmp, s); // prevent divide by zero
+ // we add one Newton iteration
+ uni_vmulps(s, s, tmp);
+ uni_vmulps(s, s, tmp); // s <- s * tmp^2
+ uni_vaddps(tmp, tmp, tmp);
+ uni_vsubps(tmp, tmp, s);
+ uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2
+ };
+
+ // dequantize from s32 to float
+ auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) {
+ // TODO: if mask is 0 precompute mul and inverse
+ if (mask == 0)
+ uni_vbroadcastss(tmp1, ptr[weights_scales_reg]);
+ else
+ uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]);
+ uni_vcvtdq2ps(s, s);
+ uni_vmulps(tmp1, tmp1, dscale_off_addr);
+ fast_recip(tmp1, tmp2, packed);
+ uni_vmulps(s, s, tmp1);
+ };
+
+ // We start code generations here
+ preamble();
+
+ // extract addresses passed as parameter
+#ifdef _WIN32
+ auto addr_ws_gates_reg = abi_param1;
+ auto addr_bias_reg = abi_param2;
+ auto addr_states_t_l_reg = abi_param3;
+ auto addr_c_states_tm1_l_reg = abi_param4;
+ auto addr_c_states_t_l_reg = r10;
+ // Here we cannot use rbp to get the initial stack pointer, so we
+ // use rsp and offset it by the size of the registers pushed in the
+ // preamble
+ mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]);
+#else
+ auto addr_ws_gates_reg = abi_param1;
+ auto addr_bias_reg = abi_param2;
+ auto addr_states_t_l_reg = abi_param3;
+ auto addr_c_states_tm1_l_reg = abi_param4;
+ auto addr_c_states_t_l_reg = abi_param5;
+#endif
+
+ // initialize registers with addresses and constants
+ mov(table_reg, table_label);
+ mov(weights_scales_reg, size_t(weights_scales));
+ // both sigmoid and tanh use the same table so load address just once in rax
+ sigmoid_injector_->load_table_addr();
+
+ mov(loop_cnt, rnn_.dic * gate_dt_size);
+ cmp(loop_cnt, vlen);
+ jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
+
+ L(vector_loop_start_label);
+ {
+ // load G0 G1 G2 G3
+ uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
+ uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
+ uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
+ uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);
+
+ // dequantize the gates from s32 to f32 if needed
+ if (src_data_t == data_type::u8){
+ deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true);
+ deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true);
+ deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true);
+ deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true);
+ }
+
+ // add biases
+ uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
+
+ // inject eltwise code
+ sigmoid_injector_->compute_vector(G0.getIdx());
+ sigmoid_injector_->compute_vector(G1.getIdx());
+ tanh_injector_->compute_vector(G2.getIdx());
+ sigmoid_injector_->compute_vector(G3.getIdx());
+
+ // compute c_states_t_l = G1 * c_tm1_l + G0 * G2
+ uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]);
+ uni_vmulps(tmp1_vmm, tmp1_vmm, G1);
+ uni_vfmadd231ps(tmp1_vmm, G0, G2);
+ uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm);
+
+ // states_t_l = G3 * tanh(c_states_t_l)
+ tanh_injector_->compute_vector(tmp1_vmm.getIdx());
+ uni_vmulps(tmp1_vmm, tmp1_vmm, G3);
+
+ // if int8, we quantize the resulting state
+ if (src_data_t == data_type::u8)
+ q_d(tmp1_vmm, tmp2_vmm);
+
+ // write back the result
+ if(vlen_dst == vlen)
+ uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm);
+ else
+ // we write only 1/4 of the register
+ switch(vlen_dst){
+ case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
+ case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
+ case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
+ default:
+ assert(!"Unsupported vector length for quantization");
+ }
+
+ // increment address pointers
+ add(addr_ws_gates_reg, vlen);
+ add(addr_bias_reg, vlen);
+ add(addr_states_t_l_reg, vlen_dst);
+ add(addr_c_states_tm1_l_reg, vlen);
+ add(addr_c_states_t_l_reg, vlen);
+ if (mask != 0)
+ add(weights_scales_reg, vlen);
+
+ // increment loop counter
+ sub(loop_cnt, vlen);
+ cmp(loop_cnt, vlen);
+ jge(vector_loop_start_label);
+ }
+ L(vector_loop_end_label);
+
+ cmp(loop_cnt, 0);
+ je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
+ // Same code as above, we just use movss for accessing inputs
+ // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
+ L(rem_loop_start_label);
+ {
+ // remapping registers to Xmms
+ Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx());
+ Xmm tmp1s_vmm(tmp1_vmm.getIdx());
+
+ // load G0 G1 G2 G3
+ uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
+ uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
+ uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
+ uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);
+
+ // dequantize the gates from s32 to f32 if needed
+ if (src_data_t == data_type::u8){
+ deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false);
+ deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false);
+ deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false);
+ deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false);
+ }
+
+ // add biases
+ uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G0s, G0s, tmp1s_vmm);
+ uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G1s, G1s, tmp1s_vmm);
+ uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G2s, G2s, tmp1s_vmm);
+ uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G3s, G3s, tmp1s_vmm);
+
+ // inject eltwise code
+ sigmoid_injector_->compute_vector(G0s.getIdx());
+ sigmoid_injector_->compute_vector(G1s.getIdx());
+ tanh_injector_->compute_vector(G2s.getIdx());
+ sigmoid_injector_->compute_vector(G3s.getIdx());
+
+ // compute c_states_t_l = G1 * c_tm1_l + G0s * G2
+ uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]);
+ uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s);
+ uni_vfmadd231ps(tmp1s_vmm, G0s, G2s);
+ uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm);
+
+ // states_t_l = G3 * tanh(c_states_t_l)
+ tanh_injector_->compute_vector(tmp1s_vmm.getIdx());
+ uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s);
+
+ // if int8, we quantize the resulting state
+ if (src_data_t == data_type::u8)
+ q_d(tmp1_vmm, tmp2_vmm);
+
+ // write back the result
+ if(vlen_dst == vlen)
+ uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm);
+ else
+ // we write only 1/4 of the register
+ switch(vlen_dst){
+ case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
+ case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
+ case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
+ default:
+ assert(!"Unsupported vector length for quantization");
+ }
+
+ // increment address pointers
+ add(addr_ws_gates_reg, gate_dt_size);
+ add(addr_bias_reg, bias_dt_size);
+ add(addr_states_t_l_reg, hstate_dt_size);
+ add(addr_c_states_tm1_l_reg, cstate_dt_size);
+ add(addr_c_states_t_l_reg, cstate_dt_size);
+ if (mask != 0)
+ add(weights_scales_reg, qscale_dt_size);
+
+ // increment loop counter
+ sub(loop_cnt, gate_dt_size);
+ cmp(loop_cnt, 0);
+ jg(rem_loop_start_label);
+
+ }
+ L(rem_loop_end_label);
+
+ postamble();
+
+ // Again, only one table is needed and shared between sigmoid and tanh
+ sigmoid_injector_->prepare_table(false);
+ tanh_injector_->prepare_table(true);
+
+ L(table_label);
+ {
+ for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale));
+ for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift));
+ // perm mask for ymm
+ dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7);
+ // perm mask for zmm
+ dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7);
+ dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14);
+ }
+ }
+
+};
+
+template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::f32>;
+template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::f32>;
+template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::f32>;
+
+template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::u8>;
+template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::u8>;
+template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::u8>;
+}
+}
+}
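
The fast_recip helper in the kernel above refines the hardware reciprocal estimate (uni_vrcpps / uni_vrcpss, roughly 12 bits of precision) with one Newton-Raphson step. A scalar model of that register sequence, assuming r0 is the initial hardware estimate; illustrative sketch only:

// r1 = 2*r0 - s*r0*r0, i.e. r0 * (2 - s*r0): one Newton-Raphson refinement of
// an approximate reciprocal r0 ~= 1/s, matching the mul/mul/add/sub chain above.
static float fast_recip_scalar(float s, float r0) {
    float t = s * r0 * r0;   // s * tmp^2
    return 2.f * r0 - t;     // 2 * tmp - s * tmp^2
}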
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
new file mode 100644
index 0000000000..ead536816c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
@@ -0,0 +1,788 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ General architecture
+
+ for diff states, we have n_states + 1 as we have n_states diffs
+ to propagate to the previous iteration and 1 state to propagate
+ to the previous layer
+ index 0 is dh for cell(t-1, l) to consume
+ index 1 is dc for cell(t-1, l) to consume
+ index 2 is dh for cell(t, l-1) to consume
+ this indexing makes it possible to use the same indexing for states in the
+ elemwise functions
+ only the cell execution function should be impacted
+
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+#include "../gemm/gemm.hpp"
+#include "../simple_q10n.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace rnn_utils;
+#define AOC array_offset_calculator
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction(
+ const rnn_conf_t &rnn, const acc_data_t *ws_gates_,
+ float *diff_bias_) const {
+ auto body = [&](int i, int k) {
+ for (int j = 0; j < rnn.mb; j++)
+ diff_bias_[i * rnn.dic + k]
+ += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k];
+ };
+
+ // @todo block k on simd-width
+#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \
+ /* icc 17.0 has a problem with simd collapse */ \
+ && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
+#pragma omp parallel for simd collapse(2)
+ for (int i = 0; i < rnn.n_gates; i++)
+ for (int k = 0; k < rnn.dic; k++)
+ body(i, k);
+#else
+ parallel_nd(rnn.n_gates, rnn.dic, body);
+#endif
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) {
+ assert(ldA * ldB * ldC != 0);
+ extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB,
+ &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm);
+}
+
+template <>
+rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) {
+ assert(!"non packed gemm is disabled for int8");
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) {
+#if (USE_MKL_PACKED_GEMM)
+ assert(transA == 'N');
+ cblas_sgemm_compute(CblasColMajor, CblasPacked,
+ (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_,
+ ldB, beta, c_, ldC);
+#else
+ UNUSED(transA);
+ UNUSED(transB);
+ UNUSED(m);
+ UNUSED(n);
+ UNUSED(k);
+ UNUSED(alpha);
+ UNUSED(ldA);
+ UNUSED(b_);
+ UNUSED(ldB);
+ UNUSED(beta);
+ UNUSED(c_);
+ UNUSED(ldC);
+ assert(!"packed gemm is disabled");
+#endif
+}
+
+template <>
+rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) {
+#if (USE_MKL_PACKED_GEMM)
+ int8_t offseta = 0, offsetb = 0;
+ int32_t offsetc = 0;
+ cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked,
+ CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_,
+ ldB, offsetb, beta, c_, ldC, &offsetc);
+#else
+ UNUSED(transA);
+ UNUSED(transB);
+ UNUSED(m);
+ UNUSED(n);
+ UNUSED(k);
+ UNUSED(alpha);
+ UNUSED(ldA);
+ UNUSED(b_);
+ UNUSED(ldB);
+ UNUSED(beta);
+ UNUSED(c_);
+ UNUSED(ldC);
+ assert(!"packed gemm is disabled");
+#endif
+}
+
+//*************** Grid computations strategy: linear ***************//
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_grid_execution_sig(
+ (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) {
+ AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
+ AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
+ AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
+ (rnn.n_states + 1), rnn.n_iter + 1,
+ rnn.states_nld * rnn.states_ws_ld);
+ AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
+ rnn.gates_nld * rnn.gates_ws_ld);
+ AOC<weights_data_t *, 3> weights_input(
+ weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
+ AOC<weights_data_t *, 3> weights_states(
+ weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
+ AOC<float*, 3> bias(
+ bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
+ AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer,
+ rnn.n_dir,
+ rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
+ AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir,
+ rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
+ AOC<float, 3> diff_bias(
+ diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
+ AOC<float, 4> ws_grid(
+ ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);
+
+ // We run the grid of computation
+ for (int dir = 0; dir < rnn.n_dir; dir++) {
+ for (int j = 0; j < rnn.n_layer; j++) {
+ int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;
+
+ if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic,
+ rnn.mb * rnn.n_iter, rnn.slc, 1.0,
+ weights_input(lay, dir, 0), rnn.weights_iter_ld,
+ &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0,
+ &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld);
+ }
+
+ for (int i = 0; i < rnn.n_iter; i++) {
+ int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1;
+ (this->*cell_func)(rnn,
+ &(ws_states(lay + 1, dir, iter + 1, 0)),
+ &(ws_c_states(lay + 1, dir, iter + 1, 0)),
+ &(ws_diff_states(lay, dir, 0, iter, 0)),
+ &(weights_input(lay, dir, 0)),
+ &(weights_states(lay, dir, 0)),
+ &(bias(lay, dir, 0)),
+ &(ws_states(lay, dir, iter + 1, 0)),
+ &(ws_states(lay + 1, dir, iter, 0)),
+ &(ws_c_states(lay + 1, dir, iter, 0)),
+ &(ws_diff_states(lay + 1, dir, 0, iter, 0)),
+ &(ws_diff_states(lay, dir, 0, iter + 1, 0)),
+ &(diff_weights_layer(lay, dir, 0)),
+ &(diff_weights_iter(lay, dir, 0)),
+ &(diff_bias(lay, dir, 0)),
+ &(ws_gates(lay, dir, iter, 0)),
+ &(ws_grid(lay, dir, iter, 0)),
+ ws_cell_);
+ }
+
+ if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
+ rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0),
+ rnn.weights_layer_ld,
+ (src_data_t *)(&(ws_gates(lay, dir, 0, 0))),
+ rnn.gates_ws_ld, 0.0,
+ (acc_data_t *)(&(ws_diff_states(
+ lay, dir, rnn.n_states, 0, 0))),
+ rnn.states_ws_ld);
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc,
+ rnn.mb * rnn.n_iter, 1.0,
+ (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
+ rnn.gates_ws_ld,
+ (src_data_t *)(&(ws_states(lay, dir, 1, 0))),
+ rnn.states_ws_ld, 1.0,
+ (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))),
+ rnn.diff_weights_layer_ld);
+ }
+ if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic,
+ rnn.mb * rnn.n_iter, 1.0,
+ (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
+ rnn.gates_ws_ld,
+ (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))),
+ rnn.states_ws_ld, 1.0,
+ (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))),
+ rnn.diff_weights_iter_ld);
+ }
+ }
+ }
+}
+
+//********* GRID computations strategy: utility functions **********//
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer(
+ const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
+ float *__restrict ws_diff_states_, const src_data_t *__restrict xt_,
+ const float *__restrict diff_dst_layer_) const {
+
+ AOC<src_data_t, 4> ws_states(
+ ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ auto xt_d = memory_desc_wrapper(pd()->src_md(0));
+
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto xxt = xt_ + xt_d.blk_off(it, b);
+ src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0));
+ src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
+ if (rnn.exec_dir != r2l)
+ for (int c = 0; c < rnn.slc; c++)
+ ws_l2r_ptr[c] = xxt[c];
+ if (rnn.exec_dir != l2r)
+ for (int c = 0; c < rnn.slc; c++)
+ ws_r2l_ptr[c] = xxt[c];
+ });
+}
+
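+/* Backward specialization: seed the last-layer diff-states workspace from
+ * diff_dst_layer, handling concatenated or summed bidirectional outputs. */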
+template <>
+void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn,
+ src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_,
+ const float *diff_dst_layer_) const {
+ AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
+ (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0));
+
+ switch (rnn.exec_dir) {
+ case bi_concat:
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < rnn.dic; s++) {
+ ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
+ = diff_dst_layer_x[s];
+ ws_diff_states(
+ rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
+ = diff_dst_layer_x[rnn.dic + s];
+ }
+ });
+ break;
+ case bi_sum:
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < rnn.dic; s++) {
+ ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
+ = diff_dst_layer_x[s];
+ ws_diff_states(
+ rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
+ = diff_dst_layer_x[s];
+ }
+ });
+ break;
+ case l2r:
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < rnn.dic; s++) {
+ ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
+ = diff_dst_layer_x[s];
+ }
+ });
+ break;
+ case r2l:
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto diff_dst_layer_x = diff_dst_layer_
+ + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
+ for (int s = 0; s < rnn.dic; s++) {
+ ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
+ = diff_dst_layer_x[s];
+ }
+ });
+ break;
+ default: assert(!"Unsupported direction"); break;
+ }
+}
+
+/* For the int8 configuration, input iteration states may be of type f32 or u8.
+ * Internally, h_state is always stored in u8 and c_state is always stored in f32.
+ * If the input states are of type u8, then h_state is copied and c_state is dequantized.
+ * If the input states are of type f32, then h_state is quantized and c_state is copied.
+ */
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+template <typename input_data_t>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter(
+ const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
+ float *__restrict ws_c_states_, float *__restrict ws_diff_states_,
+ const input_data_t *__restrict firstit_states_,
+ const float *__restrict diff_dst_iter_) const {
+ AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
+ float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
+
+ const bool quantize = pd()->with_src_iter()
+ && pd()->src_md(1)->data_type == data_type::f32
+ && rnn.dt_conf != all_f32;
+ auto maybe_q = [&](input_data_t f) {
+ if (quantize) {
+ float qf = f * data_scale + data_shift;
+ return qz_a1b0<float, src_data_t>()(qf);
+ } else
+ return (src_data_t)f;
+ };
+
+ const bool dequantize = pd()->with_src_iter()
+ && pd()->src_md(1)->data_type == data_type::u8;
+ auto maybe_deq = [&](input_data_t s) {
+ if (dequantize)
+ return (((float)s - data_shift) / data_scale);
+ else
+ return (float)s;
+ };
+ auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1));
+ if (firstit_states_) {
+ parallel_nd(
+ rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
+ for (int s = 0; s < rnn.sic; s++)
+ ws_states(lay + 1, dir, 0, b, s) = maybe_q(
+ firstit_states_[firstit_states_d.blk_off(
+ lay, dir, 0, b, s)]);
+ if (pd()->cell_kind() == alg_kind::vanilla_lstm)
+ for (int s = 0; s < rnn.sic; s++)
+ ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq(
+ firstit_states_[firstit_states_d.blk_off(
+ lay, dir, 1, b, s)]);
+ });
+ } else {
+ parallel_nd(
+ rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
+ for (int j = 0; j < rnn.sic; j++) {
+ ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0;
+ ws_c_states(lay + 1, dir, 0, b, j) = 0.0f;
+ }
+ });
+ }
+}
+
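+/* Backward specialization: seed the last-iteration diff-states workspace from
+ * diff_dst_iter, or zero it when no diff_dst_iter is provided. */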
+template <>
+template <typename input_data_t>
+void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn,
+ src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_,
+ const input_data_t *firstit_states_,
+ const float *diff_dst_iter_) const {
+ AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1));
+ if (diff_dst_iter_) {
+ parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
+ [&](int lay, int dir, int state, int b) {
+ array_copy(&(ws_diff_states(
+ lay, dir, state, rnn.n_iter, b, 0)),
+ diff_dst_iter_
+ + diff_dst_iter_d.blk_off(
+ lay, dir, state, b),
+ rnn.dic);
+ });
+ } else {
+ parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
+ [&](int lay, int dir, int state, int i) {
+ for (int j = 0; j < rnn.dic; j++)
+ ws_diff_states(lay, dir, state, rnn.n_iter, i, j)
+ = 0.0f;
+ });
+ }
+}
+
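+/* copy_res_layer writes the last layer's states from the workspace to
+ * dst_layer, dequantizing to f32 when needed and accumulating both directions
+ * for the bi_sum execution mode. */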
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+template <typename dst_data_t>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer(
+ const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer,
+ const src_data_t *ws_states_, const float *ws_diff_states_) const {
+
+ auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0));
+ AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ float shift = (pd()->attr()->rnn_data_qparams_.shift_);
+ float scale = (pd()->attr()->rnn_data_qparams_.scale_);
+
+ const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32
+ && rnn.dt_conf != all_f32;
+ auto maybe_deq = [&](src_data_t s) {
+ if (dequantize)
+ return (dst_data_t)(((float)s - shift) / scale);
+ else
+ return (dst_data_t)s;
+ };
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ int dir = 0;
+ if (rnn.exec_dir != r2l) {
+ for (int s = 0; s < rnn.dic; s++) {
+ dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
+ = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s));
+ }
+ dir = 1;
+ }
+ if (rnn.exec_dir != l2r) {
+ for (int s = 0; s < rnn.dic; s++)
+ switch (rnn.exec_dir) {
+ case bi_sum:
+ dst_layer_[dst_layer_d.blk_off(it, b, s)]
+ += maybe_deq(ws_states(
+ rnn.n_layer, dir, rnn.n_iter - it, b, s));
+ break;
+ default:
+ dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
+ = maybe_deq(ws_states(
+ rnn.n_layer, dir, rnn.n_iter - it, b, s));
+ }
+ }
+ });
+}
+
+template <>
+template <typename dst_data_t>
+void ref_rnn_bwd_f32_t::copy_res_layer(
+ const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_,
+ const src_data_t *ws_states_, const float *ws_diff_states_) const {
+ auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0));
+ AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
+ rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
+ rnn.states_ws_ld);
+
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ int dir = 0;
+ for (int s = 0; s < rnn.slc; s++) {
+ float *dst_addr = diff_src_layer_
+ + diff_src_layer_d.blk_off(
+ (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it,
+ b, dir * rnn.slc + s);
+ float res = ws_diff_states(0, 0, rnn.n_states, it, b, s);
+ if (rnn.n_dir - 1)
+ res += ws_diff_states(
+ 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s);
+ dst_addr[0] = res;
+ }
+ });
+}
+
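+/* copy_res_iter writes the last-iteration h (and, for LSTM, c) states from the
+ * workspace to dst_iter, quantizing or dequantizing as required by dst_iter's
+ * data type. */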
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+template <typename output_data_t>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter(
+ const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
+ const src_data_t *ws_states_, float *ws_c_states_,
+ const float *ws_diff_states_) const {
+ auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1));
+ AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
+ float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
+
+ const bool quantize = pd()->with_dst_iter()
+ && pd()->dst_md(1)->data_type == data_type::u8
+ && rnn.dt_conf != all_f32;
+ auto maybe_q = [&](float f) {
+ if (quantize) {
+ float qf = f * data_scale + data_shift;
+ return qz_a1b0<float, output_data_t>()(qf);
+ } else
+ return (output_data_t)f;
+ };
+
+ const bool dequantize = pd()->with_dst_iter()
+ && pd()->dst_md(1)->data_type == data_type::f32
+ && rnn.dt_conf != all_f32;
+ auto maybe_deq = [&](src_data_t s) {
+ if (dequantize)
+ return (output_data_t)(((float)s - data_shift) / data_scale);
+ else
+ return (output_data_t)s;
+ };
+ if (dst_iter_) {
+ parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
+ [&](int lay, int dir, int b) {
+ for (int s = 0; s < rnn.dic; s++) {
+ dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)]
+ = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s));
+ }
+ if (pd()->cell_kind() == alg_kind::vanilla_lstm)
+ for (int s = 0; s < rnn.dic; s++) {
+ dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)]
+ = maybe_q(ws_c_states(
+ lay + 1, dir, rnn.n_iter, b, s));
+ }
+ });
+ }
+}
+
+template <>
+template <typename output_data_t>
+void ref_rnn_bwd_f32_t::copy_res_iter(
+ const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
+ const src_data_t *ws_states_, float *ws_c_states_,
+ const float *ws_diff_states_) const {
+ auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1));
+ AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
+ rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
+ rnn.states_ws_ld);
+ if (diff_src_iter_) {
+ parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
+ [&](int lay, int dir, int state, int b) {
+ for (int s = 0; s < rnn.sic; s++) {
+ diff_src_iter_[diff_src_iter_d.blk_off(
+ lay, dir, state, b, s)]
+ = ws_diff_states(lay, dir, state, 0, b, s);
+ }
+ });
+ }
+}
+
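+/* bias_prepare fills the per-(layer, direction, part) bias pointer array,
+ * pointing into the scratch copy when copy_bias is set and into the user
+ * buffer otherwise. */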
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) {
+    /* Original set of biases provided by the user */
+ AOC<const float, 5> b(
+ b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
+ /* Array of pointers initialized in packing */
+ AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
+ AOC<float, 3> scratch_bias(
+ scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
+
+ if (rnn.copy_bias) {
+ parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic,
+ [&](size_t i) { scratch_bias_[i] = b_[i]; });
+ }
+
+ for (int i = 0; i < rnn.n_layer; i++) {
+ for (int d = 0; d < rnn.n_dir; d++) {
+ int offset_bias = 0;
+ for (int p = 0; p < rnn.n_parts_bias; p++) {
+ bias(i, d, p) = rnn.copy_bias
+ ? (float *) &scratch_bias(i, d, offset_bias)
+ : (float *) &b(i, d, offset_bias);
+ offset_bias += rnn.parts_bias[p] * rnn.dic;
+ }
+ }
+ }
+
+}
+
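+/* For int8 execution, bias_finalize folds the precomputed weights compensation
+ * terms together with the data shift/scale into the scratch bias. */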
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_bias_finalize_sig(
+ (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) {
+ if (rnn.dt_conf != all_f32) {
+ float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
+ float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
+ float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
+ bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;
+ for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
+ for (int j = 0; j < rnn.n_bias * rnn.dic; j++) {
+ size_t off = i * rnn.n_bias * rnn.dic + j;
+ float weights_scale
+ = scale_per_oc ? weights_scales[j] : weights_scales[0];
+ scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
+ * data_shift / (weights_scale * data_scale);
+ }
+ }
+}
+
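+/* assign_packed_weights walks the packed weights buffer and records, per
+ * (layer, direction, part), a pointer to the corresponding packed block. */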
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type,
+ weights_type>::assign_packed_weights)) {
+ assert(md->format_kind == format_kind::rnn_packed);
+ const auto packed_desc = md->format_desc.rnn_packed_desc;
+ AOC<weights_data_t *, 3> weights(weights_,
+ rnn.n_layer, rnn.n_dir, packed_desc.n_parts);
+
+ size_t offset_packed = 0;
+ for (int l = 0; l < rnn.n_layer; l++)
+ for (int d = 0; d < rnn.n_dir; d++) {
+ for (int p = 0; p < packed_desc.n_parts; p++) {
+ weights(l, d, p) = (weights_data_t *)&w_[offset_packed];
+ offset_packed
+ += packed_desc.part_pack_size[p] / sizeof(weights_data_t);
+ }
+ }
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_weights_assign_sig(
+ (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) {
+ assert(md->format_kind == format_kind::blocked);
+ const auto &blk = md->format_desc.blocking;
+ /* Original set of weights provided by the user */
+ AOC<const weights_data_t, 3> w(w_,
+ rnn.n_layer, rnn.n_dir, (int)blk.strides[1]);
+ /* Array of pointers for each part of weights */
+ AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);
+
+ for (int i = 0; i < rnn.n_layer; i++)
+ for (int d = 0; d < rnn.n_dir; d++) {
+ size_t offset_weights = 0;
+ for (int p = 0; p < n_parts; p++) {
+ weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights);
+ offset_weights += gates_per_part[p] * blk.strides[3];
+ }
+ }
+}
+
+//********************* Execution function *********************//
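+/* execute_ drives one primitive execution: prepare the bias/weights pointer
+ * arrays, copy the input and initial states into the workspace, run the grid
+ * computation, then copy the resulting layer and iteration states out. */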
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_(
+ const exec_ctx_t &ctx) const {
+ const rnn_conf_t &rnn = this->pd()->rnn_;
+ auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER);
+ auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER);
+ auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER);
+ auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER);
+ auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
+
+ auto dst_last_layer = rnn.is_fwd
+ ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER)
+ : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER));
+ auto dst_last_iter = rnn.is_fwd
+ ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER)
+ : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER));
+
+ auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER);
+ auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER);
+
+ auto w_layer = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp);
+ auto w_iter = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp);
+ auto w_iter_comp = reinterpret_cast<const float *>(
+ iter_weights_n_comp + rnn.weights_iter_comp_offset);
+ auto w_layer_comp = reinterpret_cast<const float *>(
+ layer_weights_n_comp + rnn.weights_layer_comp_offset);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ auto ptr_wei_layer
+ = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer);
+ auto ptr_wei_iter
+ = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter);
+ auto ptr_bias =
+ scratchpad.template get<float *>(key_rnn_ptrs_bia);
+
+    // Fetching buffers from the workspace;
+    // if no workspace was provided we use the scratchpad
+ char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
+ char *ws_ptr = nullptr;
+ if (rnn.use_workspace)
+ ws_ptr = rnn.is_fwd
+ ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE)
+ : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE));
+
+ char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
+ acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_);
+ src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_);
+ float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_);
+ float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_);
+ float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_);
+ float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_);
+
+ auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER);
+ auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER);
+
+ auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER);
+ auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER);
+ auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
+
+ // Fetching extra buffers from scratchpad
+ float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_);
+
+ // initialize diff_states to 0
+ if (aprop == prop_kind::backward)
+ array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float));
+
+    /* Pack (if using the packed gemm API) or copy (if input arrays have a bad
+     * leading dimension) */
+ (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);
+
+ (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1),
+ rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic,
+ rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter,
+ rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter,
+ ptr_bias, bias, ws_bias);
+ (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0),
+ rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc,
+ rnn.n_parts_weights_layer, rnn.parts_weights_layer,
+ rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias,
+ bias, ws_bias);
+
+ (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);
+
+ // we first need to copy the initial states and input into ws
+ copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer);
+ if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
+ || rnn.dt_conf == all_f32)
+ copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
+ (const float *)states, diff_dst_iter);
+ else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
+ copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
+ (const uint8_t *)states, diff_dst_iter);
+ else
+ assert(!"unimplemented");
+
+ // run the execution on the grid
+ (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias,
+ ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid,
+ diff_weights_layer, diff_weights_iter, diff_bias);
+
+ // Finally we copy the results to the result buffers
+ if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32
+ || rnn.dt_conf == all_f32)
+ copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states,
+ ws_diff_states);
+ else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8)
+ copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer,
+ ws_states, ws_diff_states);
+ else
+ assert(!"unimplemented");
+
+ if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
+ || rnn.dt_conf == all_f32)
+ copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states,
+ ws_c_states, ws_diff_states);
+ else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
+ copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states,
+ ws_c_states, ws_diff_states);
+ else
+ assert(!"unimplemented");
+}
+
+/* Fix for MSVS warning C4661 */
+template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
+template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
+template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
+template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
+template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise);
+
+template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
+template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
+template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
+
+#undef AOC
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp
new file mode 100644
index 0000000000..6f449a9016
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp
@@ -0,0 +1,328 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_RNN_HPP
+#define CPU_REF_RNN_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "../cpu_isa_traits.hpp"
+#include "../gemm/os_blas.hpp"
+
+#include "cpu_rnn_pd.hpp"
+#include "../cpu_primitive.hpp"
+#include "rnn_utils.hpp"
+#include "jit_uni_rnn_postgemm.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <alg_kind_t alg_kind, prop_kind_t prop_kind>
+float activation(float s, float alpha, float clipping, float dd);
+
+template <prop_kind_t aprop, impl::data_type_t src_type,
+ impl::data_type_t weights_type>
+struct _ref_rnn_common_t : public cpu_primitive_t {
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<weights_type>::type weights_data_t;
+ typedef typename utils::conditional<src_type == data_type::u8, int32_t,
+ float>::type acc_data_t;
+
+ using class_name = _ref_rnn_common_t<aprop, src_type, weights_type>;
+
+ typedef rnn_elemwise_sig((class_name::*elemwise_f));
+ typedef rnn_cell_execution_sig((class_name::*cell_execution_f));
+ typedef rnn_grid_execution_sig((class_name::*grid_execution_f));
+
+ typedef rnn_gemm_sig((class_name::*gemm_t));
+ typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t));
+ typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t));
+ typedef rnn_weights_assign_sig((class_name::*weights_assign_t));
+
+ using base_pd_t =
+ typename utils::conditional<false || aprop == prop_kind::forward,
+ cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type;
+
+ struct pd_t : public base_pd_t {
+ using base_pd_t::base_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", class_name);
+
+ status_t init() {
+ using namespace prop_kind;
+ using namespace utils;
+ using namespace format_tag;
+ using namespace rnn_utils;
+ const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind;
+
+ data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type;
+ data_type_t weights_iter_dt
+ = this->desc()->weights_iter_desc.data_type;
+ data_type_t weights_layer_dt
+ = this->desc()->weights_layer_desc.data_type;
+
+ bool ok = true
+ && one_of(cell_kind, alg_kind::vanilla_rnn,
+ alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
+ alg_kind::gru_linear_before_reset)
+ && IMPLICATION(aprop == prop_kind::forward,
+ one_of(this->desc()->prop_kind, forward_training,
+ forward_inference))
+ && IMPLICATION(aprop == backward,
+ one_of(this->desc()->prop_kind, backward))
+ && src_layer_dt == src_type
+ && everyone_is(
+ weights_type, weights_iter_dt, weights_layer_dt)
+ && this->set_default_params() == status::success
+ && this->with_bias();
+ if (!ok)
+ return status::unimplemented;
+
+ init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1),
+ this->weights_md(0), this->weights_md(1), this->dst_md(0));
+
+ if (rnn_.dt_conf == all_f32)
+ ok = ok && this->attr()->has_default_values();
+
+ // Set weights descriptors to desired format
+ memory_desc_t new_weights_layer_md = *this->weights_md(0);
+ CHECK(set_expected_desc(rnn_, new_weights_layer_md, false));
+ if (this->weights_layer_md_.format_kind == format_kind::any) {
+ this->weights_layer_md_ = new_weights_layer_md;
+ } else if (this->weights_layer_md_.format_kind
+ == format_kind::rnn_packed) {
+ if (this->weights_layer_md_ != new_weights_layer_md)
+ return status::unimplemented;
+ }
+
+ memory_desc_t new_weights_iter_md = *this->weights_md(1);
+ CHECK(set_expected_desc(rnn_, new_weights_iter_md, true));
+ if (this->weights_iter_md_.format_kind == format_kind::any) {
+ this->weights_iter_md_ = new_weights_iter_md;
+ } else if (this->weights_iter_md_.format_kind
+ == format_kind::rnn_packed) {
+ if (this->weights_iter_md_ != new_weights_iter_md)
+ return status::unimplemented;
+ }
+
+ CHECK(this->check_layout_consistency());
+
+ set_conf(rnn_, *this->desc(), this->weights_md(0),
+ this->weights_md(1), this->diff_weights_md(0),
+ this->diff_weights_md(1));
+
+ size_t scratchpad_sz{0}, ws_sz{0};
+ get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz);
+
+ // initialize the workspace if needed
+ if (rnn_.is_training) {
+ dims_t ws_dims = { (int)ws_sz };
+ mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims,
+ data_type::u8, format_tag::x);
+ }
+
+ init_scratchpad(scratchpad_sz);
+
+ return status::success;
+ }
+
+ rnn_utils::rnn_conf_t rnn_;
+
+ private:
+ void init_scratchpad(size_t scratchpad_sz) {
+ using namespace memory_tracking::names;
+ auto scratchpad = this->scratchpad_registry().registrar();
+ scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096);
+
+ int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1;
+ int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts;
+ scratchpad.book(key_rnn_ptrs_wei_layer,
+ sizeof(float *) * ptr_wei_sz);
+ scratchpad.book(key_rnn_ptrs_wei_iter,
+ sizeof(float *) * ptr_wei_sz);
+ scratchpad.book(key_rnn_ptrs_bia,
+ sizeof(float *) * ptr_wei_sz);
+ }
+ };
+
+ _ref_rnn_common_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) {
+ /// @todo set max_feature_size assuming that we limit the number of
+    /// iterations and layers to one if slc != dic and sic != dic
+ /// respectively
+
+ bias_preparation_func = &class_name::bias_prepare;
+ bias_finalization_func = &class_name::bias_finalize;
+
+ auto set_gemm_funcs
+ = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) {
+ if (packed_gemm) {
+ g = &class_name::packed_gemm;
+ a = &class_name::assign_packed_weights;
+ } else {
+ g = &class_name::gemm;
+ a = &class_name::assign_weights;
+ }
+ };
+ set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func,
+ weights_iter_assign_func);
+
+ set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func,
+ weights_layer_assign_func);
+
+ switch (pd()->cell_kind()) {
+ case alg_kind::vanilla_lstm:
+ cell_func = &class_name::cell_execution;
+ if (aprop == prop_kind::forward) {
+ if (mayiuse(avx512_core))
+ rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx512_core, src_type>(
+ pd()->rnn_, pd()->attr());
+ else if (mayiuse(avx2))
+ rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx2, src_type>(
+ pd()->rnn_, pd()->attr());
+ else if (mayiuse(sse42))
+ rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<sse42, src_type>(
+ pd()->rnn_, pd()->attr());
+ assert(rnn_postgemm_ != nullptr);
+ rnn_postgemm_->init();
+ }
+ elemwise_func = &class_name::lstm_elemwise;
+ break;
+ case alg_kind::vanilla_rnn: // @todo switch on cell kind
+ cell_func = &class_name::cell_execution;
+ elemwise_func = &class_name::rnn_elemwise;
+ switch (pd()->activation_kind()) {
+ case alg_kind::eltwise_relu:
+ activation_func = &activation<alg_kind::eltwise_relu, aprop>;
+ break;
+ case alg_kind::eltwise_tanh:
+ activation_func = &activation<alg_kind::eltwise_tanh, aprop>;
+ break;
+ case alg_kind::eltwise_logistic:
+ activation_func = &activation<alg_kind::eltwise_logistic, aprop>;
+ break;
+ default: break;
+ }
+ break;
+ case alg_kind::vanilla_gru:
+ cell_func = &class_name::cell_execution_gru;
+ break;
+ case alg_kind::gru_linear_before_reset:
+ cell_func = &class_name::cell_execution_gru_lbr;
+ elemwise_func = &class_name::gru_lbr_elemwise;
+ break;
+ default: break;
+ }
+
+ grid_computation = &class_name::linear_execution;
+
+ size_t scratchpad_size, workspace_size;
+ rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_,
+ ws_c_states_offset_, ws_diff_states_offset_,
+ ws_grid_comp_offset_, ws_cell_comp_offset_,
+ ws_bias_offset_, scratchpad_size, workspace_size);
+ }
+
+ ~_ref_rnn_common_t() {}
+
+ // typedef typename prec_traits::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_(const exec_ctx_t &ctx) const;
+ rnn_grid_execution_sig(linear_execution);
+ rnn_cell_execution_sig(cell_execution);
+ rnn_cell_execution_sig(cell_execution_gru);
+ rnn_cell_execution_sig(cell_execution_gru_lbr);
+ rnn_elemwise_sig(rnn_elemwise);
+ rnn_elemwise_sig(lstm_elemwise);
+ rnn_elemwise_sig(gru_lbr_elemwise);
+ rnn_gemm_sig(gemm);
+ rnn_gemm_sig(packed_gemm);
+ rnn_bias_prepare_sig(bias_prepare);
+ rnn_bias_finalize_sig(bias_finalize);
+ rnn_weights_assign_sig(assign_weights);
+ rnn_weights_assign_sig(assign_packed_weights);
+
+    float (*activation_func)(float dd, float s, float alpha, float clipping);
+
+ void copy_init_layer(const rnn_utils::rnn_conf_t &rnn,
+ src_data_t *ws_states_, float *ws_diff_states_,
+ const src_data_t *xt_, const float *diff_dst_layer) const;
+
+ template <typename input_data_t>
+ void copy_init_iter(const rnn_utils::rnn_conf_t &rnn,
+ src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_,
+ const input_data_t *firstit_states_,
+ const float *diff_dst_iter) const;
+
+ template <typename dst_data_t>
+ void copy_res_layer(const rnn_utils::rnn_conf_t &rnn,
+ dst_data_t *dst_layer_, float *diff_src_layer,
+ const src_data_t *ws_states_, const float *ws_diff_states_) const;
+
+ template <typename output_data_t>
+ void copy_res_iter(const rnn_utils::rnn_conf_t &rnn,
+ output_data_t *dst_iter_, float *diff_src_iter,
+ const src_data_t *ws_states_, float *ws_c_states,
+ const float *ws_diff_states_) const;
+
+ void gates_reduction(const rnn_utils::rnn_conf_t &rnn,
+ const acc_data_t *ws_gates_, float *diff_bias_) const;
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ size_t ws_gates_offset_;
+ size_t ws_states_offset_;
+ size_t ws_c_states_offset_;
+ size_t ws_bias_offset_;
+ size_t ws_diff_states_offset_;
+ size_t ws_grid_comp_offset_;
+ size_t ws_cell_comp_offset_;
+ jit_uni_rnn_postgemm_kernel *rnn_postgemm_;
+
+ grid_execution_f grid_computation;
+ cell_execution_f cell_func;
+
+ bias_prepare_t bias_preparation_func;
+ bias_finalize_t bias_finalization_func;
+ weights_assign_t weights_layer_assign_func;
+ weights_assign_t weights_iter_assign_func;
+
+ gemm_t gemm_layer_func;
+ gemm_t gemm_iter_func;
+ elemwise_f elemwise_func;
+};
+
+using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
+using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
+using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
+}
+}
+}
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
new file mode 100644
index 0000000000..597c63e3f8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
@@ -0,0 +1,380 @@
+/*******************************************************************************
+ * Copyright 2018 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#ifndef CPU_RNN_REORDERS_HPP
+#define CPU_RNN_REORDERS_HPP
+
+#include <assert.h>
+
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+#include "simple_q10n.hpp"
+#include "cpu_reorder_pd.hpp"
+#include "../gemm/os_blas.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t type_i, data_type_t type_o>
+struct rnn_data_reorder_t : public cpu_primitive_t {
+ struct pd_t : public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+ const memory_desc_wrapper id(src_md), od(dst_md);
+ bool args_ok = true
+ && id.data_type() == type_i
+ && od.data_type() == type_o
+ && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc)
+ && od == id;
+ if (!args_ok) return status::invalid_arguments;
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return out_of_memory;
+ if (_pd->init() != success) { delete _pd; return unimplemented; }
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+ };
+
+private:
+ typedef typename prec_traits<type_i>::type in_data_t;
+ typedef typename prec_traits<type_o>::type out_data_t;
+
+ rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
+ auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO);
+ const memory_desc_wrapper &input_d = pd()->src_md();
+ const memory_desc_wrapper &output_d = pd()->dst_md();
+ const size_t nelems = input_d.nelems();
+ const float scale = pd()->attr()->rnn_data_qparams_.scale_;
+ const float shift = pd()->attr()->rnn_data_qparams_.shift_;
+
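+        // Quantize element-wise: scale and shift in f32, then round and
+        // saturate to the output data type.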
+ parallel_nd(nelems, [&](size_t i) {
+ float in = (float)input[input_d.off_l(i)] * scale + shift;
+ output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in);
+ });
+
+ return status::success;
+ }
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <data_type_t type_i, data_type_t type_o>
+struct rnn_weights_reorder_t : public cpu_primitive_t {
+ struct pd_t : public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+#if !USE_MKL_PACKED_GEMM
+ return status::unimplemented;
+#endif
+ const memory_desc_wrapper id(src_md), od(dst_md);
+ bool args_ok = true
+ && id.data_type() == type_i
+ && od.data_type() == type_o
+ && od.format_kind() == format_kind::rnn_packed
+ && od.rnn_packed_desc().format == mkldnn_ldigo_p
+ && od.rnn_packed_desc().n_parts == 1
+ && attr != nullptr;
+ if (!args_ok) return status::invalid_arguments;
+
+ format_tag_t itag = id.matches_one_of_tag(
+ format_tag::ldigo, format_tag::ldgoi);
+ if (itag == format_tag::undef) return status::invalid_arguments;
+
+ const int mask = attr->rnn_weights_qparams_.mask_;
+ if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return out_of_memory;
+ _pd->itag_ = itag;
+ if (_pd->init() != success) { delete _pd; return unimplemented; }
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+
+ status_t init() {
+ status_t status = cpu_reorder_pd_t::init();
+ if (status != status::success) return status;
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ format_tag_t itag_;
+
+ private:
+ void init_scratchpad() {
+ const memory_desc_wrapper id(src_md());
+ const size_t nelems = id.nelems();
+ const auto &dims = id.dims();
+
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ size_t quantization_size = sizeof(int8_t) * nelems;
+ size_t reduction_size = itag_ == ldigo
+ ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0]
+ * dims[1] * dims[3] * dims[4]
+ : 0;
+ scratchpad.book(
+ key_reorder_rnn_weights_quantization, quantization_size);
+ scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size);
+ }
+ };
+
+private:
+ typedef typename prec_traits<type_i>::type in_data_t;
+ typedef typename prec_traits<type_o>::type out_data_t;
+
+ rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+#if USE_MKL_PACKED_GEMM
+ auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
+ auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO);
+ const memory_desc_wrapper &input_d = pd()->src_md();
+ const memory_desc_wrapper &output_d = pd()->dst_md();
+ const auto &dims = input_d.dims();
+
+ const int L = dims[0];
+ const int D = dims[1];
+ const int I = dims[2];
+ const int G = dims[3];
+ const int O = dims[4];
+
+ const bool is_igo = pd()->itag_ == format_tag::ldigo;
+
+ /* Quantize input & compute compensation */
+ auto quantized = (int8_t * __restrict)scratchpad(ctx).template get<void>(
+ memory_tracking::names::key_reorder_rnn_weights_quantization);
+ auto reduction = (int32_t * __restrict)scratchpad(ctx).template get<void>(
+ memory_tracking::names::key_reorder_rnn_weights_reduction);
+ float *comp = reinterpret_cast<float *>(
+ output + output_d.rnn_packed_desc().offset_compensation);
+ const float *scales = pd()->attr()->rnn_weights_qparams_.scales_;
+ const int mask = pd()->attr()->rnn_weights_qparams_.mask_;
+
+ if (is_igo) {
+ int nthr = mkldnn_get_max_threads();
+ int LD_nthr = nstl::min(L * D, nthr);
+ int I_nthr = nstl::min(I, nthr / LD_nthr);
+ parallel(nthr, [&](const int ithr, const int nthr) {
+ int LD_ithr = -1, LD_s = -1, LD_e = -1;
+ int I_ithr = -1, I_s = -1, I_e = -1;
+ if (ithr < LD_nthr * I_nthr) {
+ LD_ithr = ithr % LD_nthr;
+ I_ithr = ithr / LD_nthr;
+ balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
+ balance211(I, I_nthr, I_ithr, I_s, I_e);
+ }
+ int32_t *comp_ithr = reduction + I_ithr * L * D * G * O;
+ for (int ld = LD_s; ld < LD_e; ld++) {
+ for (int go = 0; go < G * O; go++)
+ comp_ithr[ld * G * O + go] = 0;
+ for (int i = I_s; i < I_e; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int go = 0; go < G * O; go++) {
+ const float s = scales[(mask == 0) ? 0 : go];
+ int8_t q = qz_b0<in_data_t, out_data_t>()(
+ input[ld * I * G * O + i * G * O + go], s);
+ quantized[ld * I * G * O + i * G * O + go]
+ = (int32_t)q;
+ comp_ithr[ld * G * O + go] += (int32_t)q;
+ }
+ }
+ }
+ });
+ parallel_nd(L * D * G * O,
+ [&](int s) { comp[s] = saturate<float>(reduction[s]); });
+ for (int i = 1; i < I_nthr; i++) {
+ parallel_nd(L * D * G * O, [&](int s) {
+ comp[s] += saturate<float>(
+ reduction[i * L * D * G * O + s]);
+ });
+ }
+ } else {
+ parallel_nd(L * D, G * O, [&](int ld, int go) {
+ int32_t compensation = 0;
+ const float s = scales[(mask == 0) ? 0 : go];
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < I; i++) {
+ int8_t q = qz_b0<in_data_t, out_data_t>()(
+ input[ld * G * O * I + go * I + i], s);
+ compensation += (int32_t)q;
+ quantized[ld * G * O * I + go * I + i] = q;
+ }
+ comp[ld * G * O + go] = saturate<float>(compensation);
+ });
+ }
+
+ /* Pack */
+ auto off_igo = [&](int l, int d, int i, int g, int o) {
+ return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
+ };
+ auto off_goi = [&](int l, int d, int i, int g, int o) {
+ return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
+ };
+ int n_parts = output_d.rnn_packed_desc().n_parts;
+ const size_t *size_packed_cell
+ = output_d.rnn_packed_desc().part_pack_size;
+ const int *parts = output_d.rnn_packed_desc().parts;
+ const int n = output_d.rnn_packed_desc().n;
+ char *to_pack = output;
+ for (int l = 0; l < L; l++) {
+ for (int d = 0; d < D; d++) {
+ for (int p = 0; p < n_parts; p++) {
+ int g = (p > 0) ? parts[p - 1] : 0;
+ int m_p = parts[p] * O;
+ int k_p = I;
+ cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix,
+ is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p,
+ &quantized[is_igo ? off_igo(l, d, 0, g, 0) :
+ off_goi(l, d, g, 0, 0)],
+ is_igo ? G * O : I, to_pack);
+ to_pack += size_packed_cell[p];
+ }
+ }
+ }
+#endif
+ return status::success;
+ }
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <>
+struct rnn_weights_reorder_t<data_type::f32, data_type::f32>
+ : public cpu_primitive_t {
+ struct pd_t : public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+#if !USE_MKL_PACKED_GEMM
+ return status::unimplemented;
+#endif
+ const memory_desc_wrapper id(src_md), od(dst_md);
+ bool args_ok = true
+ && id.data_type() == data_type::f32
+ && od.data_type() == data_type::f32
+ && od.format_kind() == format_kind::rnn_packed
+ && utils::one_of(od.rnn_packed_desc().format,
+ mkldnn_ldigo_p, mkldnn_ldgoi_p)
+ && attr->has_default_values();
+ if (!args_ok) return status::invalid_arguments;
+
+ format_tag_t itag = id.matches_one_of_tag(
+ format_tag::ldigo, format_tag::ldgoi);
+ if (itag == format_tag::undef) return status::invalid_arguments;
+
+ const int mask = attr->rnn_weights_qparams_.mask_;
+ if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return out_of_memory;
+ if (_pd->init() != success) { delete _pd; return unimplemented; }
+ _pd->itag_ = itag;
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+
+ format_tag_t itag_;
+ };
+
+private:
+ rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+#if USE_MKL_PACKED_GEMM
+ auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM);
+ auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO);
+ const memory_desc_wrapper &input_d = pd()->src_md();
+ const memory_desc_wrapper &output_d = pd()->dst_md();
+ const auto &dims = input_d.dims();
+ const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc();
+ const int L = dims[0];
+ const int D = dims[1];
+ const int I = dims[2];
+ const int G = dims[3];
+ const int O = dims[4];
+
+ /* Pack */
+ bool cross_case = false
+ || (pd()->itag_ == format_tag::ldigo
+ && rnn_pdata.format == mkldnn_ldgoi_p)
+ || (pd()->itag_ == format_tag::ldgoi
+ && rnn_pdata.format == mkldnn_ldigo_p);
+ auto trans = cross_case ? CblasTrans : CblasNoTrans;
+ int n_parts = rnn_pdata.n_parts;
+ const size_t *size_packed_cell = rnn_pdata.part_pack_size;
+ const int *parts = rnn_pdata.parts;
+ const int n = rnn_pdata.n;
+
+ const bool is_igo = pd()->itag_ == format_tag::ldigo;
+ auto off_igo = [&](int l, int d, int i, int g, int o) {
+ return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
+ };
+ auto off_goi = [&](int l, int d, int i, int g, int o) {
+ return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
+ };
+ for (int l = 0; l < L; l++) {
+ for (int d = 0; d < D; d++) {
+ for (int p = 0; p < n_parts; p++) {
+ int g = (p > 0) ? parts[p - 1] : 0;
+ int m_p = is_igo ? parts[p] * O : I;
+ int k_p = is_igo ? I : parts[p] * O;
+ int ld = is_igo ? G * O : I;
+ cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n,
+ k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) :
+ off_goi(l, d, 0, g, 0)],
+ ld, output);
+ output += size_packed_cell[p] / sizeof(float);
+ }
+ }
+ }
+#endif
+ return status::success;
+ }
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp
new file mode 100644
index 0000000000..1d60415cbc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp
@@ -0,0 +1,426 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+#include "rnn_utils.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace rnn_utils;
+using namespace format_tag;
+using namespace rnn_packed_format;
+using namespace data_type;
+
+bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) {
+ if (md.format_kind() != format_kind::blocked)
+ return false;
+
+ auto blk = md.blocking_desc();
+ auto str = blk.strides;
+ auto dims = md.dims();
+ return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1
+ && str[3] == dims[4] && str[1] == str[2] * dims[2]
+ && str[0] == str[1] * dims[1];
+}
+
+bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) {
+ if (md.format_kind() != format_kind::blocked)
+ return false;
+
+ auto blk = md.blocking_desc();
+ auto str = blk.strides;
+ auto dims = md.dims();
+ return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1
+ && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3]
+ && str[0] == str[1] * dims[1];
+}
+
+void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
+ const memory_desc_wrapper &src_layer_d,
+ const memory_desc_wrapper &src_iter_d,
+ const memory_desc_wrapper &weights_layer_d,
+ const memory_desc_wrapper &weights_iter_d,
+ const memory_desc_wrapper &dst_layer_d) {
+ rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ rnn.is_training = utils::one_of(
+ rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
+ rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset;
+
+ switch (rd.direction) {
+ case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break;
+ case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break;
+ case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break;
+ case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break;
+ default: break;
+ }
+
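+    // Pick the data type configuration from the layer/iter/weights data
+    // types: all_f32 or one of the int8 variants (u8 vs f32 states/output).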
+ if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(),
+ weights_layer_d.data_type()))
+ rnn.dt_conf = all_f32;
+ else if (dst_layer_d.data_type() == u8) {
+ if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
+ rnn.dt_conf = u8u8u8u8;
+ else
+ rnn.dt_conf = f32u8f32u8;
+ } else {
+ if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
+ rnn.dt_conf = u8u8u8f32;
+ else
+ rnn.dt_conf = f32u8f32f32;
+ }
+
+ rnn.n_layer = weights_layer_d.dims()[0];
+ rnn.n_iter = src_layer_d.dims()[0];
+ rnn.n_dir = weights_layer_d.dims()[1];
+ rnn.n_gates = weights_layer_d.dims()[3];
+ rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc);
+ rnn.n_bias = rnn.n_gates + rnn.is_lbr;
+ rnn.mb = src_layer_d.dims()[1];
+ rnn.sic = weights_iter_d.dims()[2];
+ rnn.slc = weights_layer_d.dims()[2];
+ rnn.dic = weights_layer_d.dims()[4];
+ rnn.dlc = dst_layer_d.dims()[2];
+
+ rnn.gates_ld = rnn.dic * rnn.n_gates;
+ rnn.gates_nld = rnn.mb;
+ rnn.states_nld = rnn.mb;
+
+ /* Set the correct number of weights parts */
+ bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru;
+ rnn.n_parts_weights_layer = 1;
+ rnn.parts_weights_layer[0] = rnn.n_gates;
+ rnn.parts_weights_layer[1] = 0;
+
+ rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1;
+ rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates;
+ rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0;
+
+ rnn.n_parts_bias = 1;
+ rnn.parts_bias[0] = rnn.n_bias;
+ rnn.parts_bias[1] = 0;
+
+    /* Decide which gemm implementation to use: packed/nonpacked jit/cblas
+     * and whether to merge gemm across iterations */
+ bool is_int8 = rnn.dt_conf != all_f32;
+ rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd)
+ || is_int8;
+ bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru,
+ alg_kind::gru_linear_before_reset);
+ rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8;
+ bool is_inference = !rnn.is_training;
+
+ rnn.use_jit_gemm = !mayiuse(avx512_mic)
+ && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100))
+ || (rnn.is_training && rnn.dic < 500));
+
+    /* Decide whether to copy the bias */
+ rnn.copy_bias = rnn.dt_conf != all_f32;
+
+#if USE_MKL_PACKED_GEMM
+ rnn.use_layer_packed_gemm
+ = (weights_layer_d.format_kind() == format_kind::any
+ && rnn.slc > 760 && rnn.dic > 760 && is_inference)
+ || is_int8; // packed gemm is the only supported option for int8
+ rnn.use_iter_packed_gemm
+ = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760
+ && rnn.dic > 760 && is_inference)
+ || is_int8;
+#else
+ rnn.use_layer_packed_gemm = false;
+ rnn.use_iter_packed_gemm = false;
+#endif
+
+ /* Set packed gemm sizes */
+ if (rnn.use_layer_packed_gemm) {
+ rnn.weights_layer_pack_size = 0;
+ for (int p = 0; p < rnn.n_parts_weights_layer; p++) {
+ int m_p = rnn.is_fwd
+ ? (rnn.parts_weights_layer[p] * rnn.dic)
+ : rnn.slc;
+ int k_p = rnn.is_fwd
+ ? rnn.slc
+ : (rnn.parts_weights_layer[p] * rnn.dic);
+ int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb;
+
+#if USE_MKL_PACKED_GEMM
+ if (rnn.dt_conf == all_f32)
+ rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size(
+ CblasAMatrix, m_p, n_p, k_p);
+ else
+ rnn.part_weights_layer_pack_size[p]
+ = cblas_gemm_s8u8s32_pack_get_size(
+ CblasAMatrix, m_p, n_p, k_p);
+#else
+ UNUSED(m_p);
+ UNUSED(k_p);
+ UNUSED(n_p);
+ rnn.part_weights_layer_pack_size[p] = 0;
+#endif
+ rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir
+ * rnn.part_weights_layer_pack_size[p];
+ }
+ rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size;
+ rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
+ * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float);
+ }
+
+ if (rnn.use_iter_packed_gemm) {
+ rnn.weights_iter_pack_size = 0;
+ for (int p = 0; p < rnn.n_parts_weights_iter; p++) {
+ int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) :
+ rnn.sic;
+ int k_p = rnn.is_fwd ? rnn.sic :
+ (rnn.parts_weights_iter[p] * rnn.dic);
+ int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb;
+
+#if USE_MKL_PACKED_GEMM
+ if (rnn.dt_conf == all_f32)
+ rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size(
+ CblasAMatrix, m_p, n_p, k_p);
+ else
+ rnn.part_weights_iter_pack_size[p]
+ = cblas_gemm_s8u8s32_pack_get_size(
+ CblasAMatrix, m_p, n_p, k_p);
+#else
+ UNUSED(m_p);
+ UNUSED(k_p);
+ UNUSED(n_p);
+ rnn.part_weights_iter_pack_size[p] = 0;
+#endif
+ rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir
+ * rnn.part_weights_iter_pack_size[p];
+ }
+ rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size;
+ rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
+ * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float);
+ }
+
+}
+
+void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
+ const memory_desc_wrapper &weights_layer_d,
+ const memory_desc_wrapper &weights_iter_d,
+ const memory_desc_wrapper &diff_weights_layer_d,
+ const memory_desc_wrapper &diff_weights_iter_d) {
+
+ /* Set leading dimensions for input weights arrays depending on input format
+ */
+ rnn.weights_layer_is_packed
+ = weights_layer_d.format_kind() == format_kind::rnn_packed;
+ rnn.weights_iter_is_packed
+ = weights_iter_d.format_kind() == format_kind::rnn_packed;
+
+ auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
+ ld = 0; nld = 0;
+ if (md.is_blocking_desc()) {
+ if (is_ldigo(md)) {
+ ld = (int)md.blocking_desc().strides[2];
+ nld = md.dims()[2];
+ } else if (is_ldgoi(md)) {
+ ld = (int)md.blocking_desc().strides[4];
+ nld = md.dims()[3] * md.dims()[4];
+ } else
+ assert(!"unsupported weights format");
+ }
+ };
+ set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
+ set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
+ if (!rnn.is_fwd) {
+ set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
+ rnn.diff_weights_layer_nld);
+ set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
+ rnn.diff_weights_iter_nld);
+ }
+
+ int sizeof_states_dt
+ = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t);
+ rnn.states_ws_ld
+ = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)),
+ sizeof_states_dt);
+ rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float));
+
+ /* Set workspace sizes to store:
+     * states to compute a pass
+     * diff states to compute the bwd pass (training only)
+ * intermediate results from the gates
+ */
+ rnn.use_workspace = rnn.is_training;
+ rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
+ * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt;
+ bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm;
+ rnn.ws_c_states_size = is_lstm
+ ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
+ * rnn.states_ws_ld * sizeof(float)
+ : 0;
+ rnn.ws_diff_states_size = rnn.is_training
+ ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1)
+ * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld
+ * sizeof(float)
+ : (size_t)0;
+ rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb
+ * rnn.gates_ws_ld * sizeof(float);
+
+ /* set other sizes */
+ rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float);
+ rnn.ws_cell_comp_size
+ = rnn.is_lbr || rnn.dt_conf != all_f32
+ ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float)
+ : 0;
+ rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
+ * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float);
+ rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic
+ * sizeof(float);
+}
+
+int rnn_utils::get_good_ld(int dim, int sizeof_dt) {
+    // we want matrix leading dimensions to be 64-byte aligned,
+ // and not divisible by 256 to avoid 4K aliasing effects
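+    // e.g. for f32 (sizeof_dt == 4): dim == 512 -> rnd_up(512, 16) == 512,
+    // which is divisible by 256, so we return 512 + 16 == 528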
+ int ld = rnd_up(dim, 64 / sizeof_dt);
+ return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
+}
+
+void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
+ size_t &ws_states_offset, size_t &ws_c_states_offset,
+ size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
+ size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
+ size_t &scratchpad_size, size_t &workspace_size) {
+
+ const size_t page_size = 4096; // 2097152;
+ size_t current_offset;
+ /* Mandatory workspaces: go to workspace if use_workspace, scratchpad
+ * otherwise */
+ current_offset = 0; // assumes the workspace base pointer is page aligned
+ ws_gates_offset = current_offset;
+ current_offset += rnn.ws_gates_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_states_offset = current_offset;
+ current_offset += rnn.ws_states_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_c_states_offset = current_offset;
+ current_offset += rnn.ws_c_states_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_diff_states_offset = current_offset;
+ current_offset += rnn.ws_diff_states_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_grid_comp_offset = current_offset;
+ current_offset += rnn.ws_grid_comp_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_cell_comp_offset = current_offset;
+ current_offset += rnn.ws_cell_comp_size;
+
+ workspace_size = rnn.use_workspace ? current_offset : 0;
+
+ /* Optional scratchpads */
+ // Assumes the scratchpad base pointer is page aligned.
+ // If use_workspace, only the following goes to the scratchpad;
+ // otherwise everything goes to the scratchpad and we keep incrementing the offset.
+ current_offset = rnn.use_workspace ? 0 : current_offset;
+
+ if (rnn.copy_bias) {
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_bias_offset = current_offset;
+ current_offset += rnn.ws_bias_size;
+ }
+
+ scratchpad_size = current_offset;
+}
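The layout computed above is simply a sequence of page-aligned carve-outs from one base allocation; below is a small self-contained sketch of that pattern with hypothetical tensor sizes (not the real rnn_conf_t fields).

#include <cstddef>
#include <cstdio>

// Round an offset up to the next page boundary, as utils::rnd_up does above.
static size_t round_up(size_t a, size_t b) { return ((a + b - 1) / b) * b; }

int main() {
    const size_t page_size = 4096;
    // Hypothetical per-tensor sizes in bytes: gates, states, c_states.
    const size_t sizes[3] = {10000, 3000, 0};
    size_t offsets[3];

    size_t current = 0; // assumes the base pointer is page aligned
    for (int i = 0; i < 3; ++i) {
        current = round_up(current, page_size);
        offsets[i] = current;
        current += sizes[i];
    }
    // Prints: 0 12288 16384 total=16384
    printf("%zu %zu %zu total=%zu\n",
            offsets[0], offsets[1], offsets[2], current);
    return 0;
}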
+
+void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
+ size_t &scratchpad_size, size_t &workspace_size) {
+ size_t ws_gates_offset, ws_states_offset, ws_c_states_offset,
+ ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
+ ws_bias_offset;
+ set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_c_states_offset,
+ ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
+ ws_bias_offset, scratchpad_size, workspace_size);
+}
+
+status_t rnn_utils::set_good_strides(
+ memory_desc_t &weights_md, format_tag_t tag) {
+ auto &strides = weights_md.format_desc.blocking.strides;
+ auto dims = weights_md.dims;
+
+ if (tag == ldigo) {
+ strides[2] = rnn_utils::get_good_ld((int)strides[2],
+ (int)types::data_type_size(weights_md.data_type));
+ strides[1] = dims[2] * strides[2];
+ strides[0] = dims[1] * strides[1];
+ } else if (tag == ldgoi) {
+ strides[4] = rnn_utils::get_good_ld((int)strides[4],
+ (int)types::data_type_size(weights_md.data_type));
+ strides[3] = dims[4] * strides[4];
+ strides[1] = dims[3] * strides[3];
+ strides[0] = dims[1] * strides[1];
+ } else
+ return status::unimplemented;
+
+ return status::success;
+}
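A standalone sketch of the stride adjustment above for an ldigo weights tensor; the dimensions are hypothetical and the padded value is what get_good_ld would return for f32.

#include <cstdio>

int main() {
    // Hypothetical ldigo dims: layers L, directions D, input I, gates G, output O.
    const long dims[5] = {1 /*L*/, 1 /*D*/, 64 /*I*/, 4 /*G*/, 125 /*O*/};
    long strides[5];

    // Dense row-major strides, O innermost.
    strides[4] = 1;
    strides[3] = dims[4];               // 125
    strides[2] = dims[3] * strides[3];  // 500: leading dimension over I
    // set_good_strides pads the leading dimension (get_good_ld(500, 4) == 528)
    // and recomputes the outer strides from the padded value.
    strides[2] = 528;
    strides[1] = dims[2] * strides[2];  // 33792
    strides[0] = dims[1] * strides[1];  // 33792
    printf("%ld %ld %ld %ld %ld\n",
            strides[0], strides[1], strides[2], strides[3], strides[4]);
    return 0;
}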
+
+status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
+ memory_desc_t &weights_md, bool is_iter) {
+ using namespace format_tag;
+ bool use_packed_gemm = is_iter
+ ? rnn.use_iter_packed_gemm
+ : rnn.use_layer_packed_gemm;
+ if (use_packed_gemm) {
+ weights_md.format_kind = format_kind::rnn_packed;
+ rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc;
+ rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p;
+ if (is_iter) {
+ rnn_pdata.n = rnn.mb;
+ rnn_pdata.n_parts = rnn.n_parts_weights_iter;
+ array_copy(rnn_pdata.parts, rnn.parts_weights_iter,
+ MKLDNN_RNN_MAX_N_PARTS);
+ array_copy(rnn_pdata.part_pack_size,
+ rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS);
+ rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset;
+ rnn_pdata.size = rnn.weights_iter_pack_size;
+ } else {
+ rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb;
+ rnn_pdata.n_parts = rnn.n_parts_weights_layer;
+ array_copy(rnn_pdata.parts, rnn.parts_weights_layer,
+ MKLDNN_RNN_MAX_N_PARTS);
+ array_copy(rnn_pdata.part_pack_size,
+ rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS);
+ rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset;
+ rnn_pdata.size = rnn.weights_layer_pack_size;
+ }
+ } else {
+ CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi));
+ // Adjust strides for good leading dimension in GEMM
+ CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi));
+ }
+ return status::success;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp
new file mode 100644
index 0000000000..99eb787a64
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp
@@ -0,0 +1,225 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef RNN_UTILS_HPP
+#define RNN_UTILS_HPP
+
+#include "mkldnn.h"
+
+#include "cpu_rnn_pd.hpp"
+
+
+#define rnn_elemwise_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \
+ src_data_t *states_t_l_, float *c_states_t_l_, \
+ src_data_t *states_tm1_l_, float *c_states_tm1_l_, \
+ float *diff_states_t_l_, float *diff_states_t_lp1_, \
+ float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \
+ float *ws_cell_) const
+
+#define rnn_cell_execution_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \
+ float *c_states_t_l_, float *diff_states_t_l_, \
+ weights_data_t **w_layer_, weights_data_t **w_iter_, \
+ float **bias_, src_data_t *states_t_lm1_, \
+ src_data_t *states_tm1_l_, float *c_states_tm1_l_, \
+ float *diff_states_t_lp1_, float *diff_states_tp1_l_, \
+ float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \
+ acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const
+
+#define rnn_grid_execution_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \
+ weights_data_t **weights_states_, float **bias_, \
+ src_data_t *ws_states_, float *ws_c_states_, \
+ float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \
+ float *ws_grid_, float *diff_weights_layer_, \
+ float *diff_weights_iter_, float *diff_bias_) const
+
+#define rnn_gemm_sig(f) \
+ void f(const char transA, const char transB, int m, int n, int k, \
+ const float alpha, const weights_data_t *a_, const int ldA, \
+ const src_data_t *b_, const int ldB, const float beta, \
+ acc_data_t *c_, const int ldC) const
+
+#define rnn_bias_prepare_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \
+ float *scratch_bias_) const
+
+#define rnn_bias_finalize_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \
+ const float *w_iter_comp, const float *w_layer_comp) const
+
+#define rnn_weights_assign_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \
+ int ld, int OC_size, int IC_size, const int n_parts, \
+ const int *gates_per_part, const size_t *part_weights_pack_size, \
+ weights_data_t **weights_, const weights_data_t *w_, \
+ float **bias_, const float *b_, float *scratch_bias_) const
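A minimal toy sketch of the signature-macro pattern used above, with toy types and a toy macro rather than the real ones: the macro spells out the long parameter list once, so member declarations, out-of-class definitions, and member-function pointers all stay in sync.

#include <cstdio>

// Toy stand-in for rnn_gemm_sig and friends.
#define toy_gemm_sig(f) \
    void f(char transA, char transB, int m, int n, int k, float alpha) const

struct toy_rnn {
    toy_gemm_sig(gemm_blas); // declares void gemm_blas(...) const
    toy_gemm_sig(gemm_jit);  // declares void gemm_jit(...) const
};

void toy_rnn::gemm_blas(char, char, int m, int n, int k, float) const {
    printf("blas gemm %dx%dx%d\n", m, n, k);
}
void toy_rnn::gemm_jit(char, char, int m, int n, int k, float) const {
    printf("jit gemm %dx%dx%d\n", m, n, k);
}

int main() {
    // The same parameter list also keeps a pointer-to-member function type
    // consistent, which allows selecting a gemm implementation at runtime.
    void (toy_rnn::*gemm_func)(char, char, int, int, int, float) const
            = &toy_rnn::gemm_blas;
    toy_rnn r;
    (r.*gemm_func)('N', 'N', 4, 8, 16, 1.0f);
    return 0;
}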
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace rnn_utils {
+
+using namespace mkldnn::impl::utils;
+
+enum execution_direction_t {
+ l2r,
+ r2l,
+ bi_concat,
+ bi_sum,
+};
+
+enum data_type_conf_t {
+ all_f32,
+ u8u8u8f32,
+ f32u8f32f32,
+ u8u8u8u8,
+ f32u8f32u8
+};
+
+struct rnn_conf_t {
+ execution_direction_t exec_dir;
+ data_type_conf_t dt_conf;
+ int n_layer, n_iter, n_dir, n_gates, n_states;
+ int mb;
+ int slc, sic, dic, dlc;
+ int gates_ld, gates_nld, gates_ws_ld;
+ int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS];
+ int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS];
+ int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS];
+ size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS],
+ part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS];
+ bool weights_layer_is_packed, weights_iter_is_packed;
+ /* Size of packed data in bytes */
+ size_t weights_layer_comp_offset, weights_layer_pack_size,
+ weights_iter_comp_offset, weights_iter_pack_size;
+
+ bool copy_bias;
+ int weights_layer_ld, weights_layer_nld;
+ int diff_weights_layer_ld, diff_weights_layer_nld;
+ int weights_iter_ld, weights_iter_nld;
+ int diff_weights_iter_ld, diff_weights_iter_nld;
+ int states_nld, states_ws_ld;
+ int weights_iter_compensation_size, weights_layer_compensation_size;
+ bool is_fwd, is_training, is_lbr;
+ bool use_workspace;
+
+ /* Size of workspace for each tensor in bytes */
+ size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size,
+ ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size;
+ bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm,
+ use_iter_packed_gemm;
+};
+
+bool is_ldigo(const memory_desc_wrapper &md);
+bool is_ldgoi(const memory_desc_wrapper &md);
+
+int get_good_ld(int dim, int sizeof_dt);
+
+void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
+ const memory_desc_wrapper &src_layer_d,
+ const memory_desc_wrapper &src_iter_d,
+ const memory_desc_wrapper &weights_layer_d,
+ const memory_desc_wrapper &weights_iter_d,
+ const memory_desc_wrapper &dst_layer_d);
+
+void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
+ const memory_desc_wrapper &weights_layer_d,
+ const memory_desc_wrapper &weights_iter_d,
+ const memory_desc_wrapper &diff_weights_layer_d,
+ const memory_desc_wrapper &diff_weights_iter_d);
+
+void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
+ size_t &ws_states_offset, size_t &ws_c_states_offset,
+ size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
+ size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
+ size_t &scratchpad_size, size_t &workspace_size);
+
+void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
+ size_t &scratchpad_size, size_t &workspace_size);
+status_t set_expected_desc(
+ rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter);
+status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag);
+
+template <typename T>
+struct ws_gates_aoc {
+ ws_gates_aoc(const rnn_conf_t &rnn, T *data)
+ : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {}
+ T &operator()(int batch, int gate, int dic) {
+ return gates_(batch, gate * DIC_ + dic);
+ }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<T, 2> gates_;
+ int DIC_;
+};
+using ws_gates_aoc_t = ws_gates_aoc<float>;
+using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>;
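A self-contained sketch of the indexing convention behind ws_gates_aoc, using a simplified hypothetical stand-in for array_offset_calculator and invented sizes.

#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-in for array_offset_calculator<T, 2>: row-major indexing
// with an explicit (possibly padded) leading dimension.
template <typename T>
struct aoc2 {
    aoc2(T *data, int /*nld*/, int ld) : data_(data), ld_(ld) {}
    T &operator()(int i0, int i1) { return data_[(size_t)i0 * ld_ + i1]; }
    T *data_;
    int ld_;
};

int main() {
    const int mb = 2, dic = 8;
    const int gates_ws_ld = 48; // padded beyond n_gates * dic == 32
    std::vector<float> ws(mb * gates_ws_ld, 0.f);

    aoc2<float> gates(ws.data(), mb, gates_ws_ld);
    // ws_gates_aoc addresses gate g, channel d of batch b as (b, g * dic + d).
    gates(1, 2 * dic + 3) = 1.f;
    assert(ws[1 * gates_ws_ld + 2 * dic + 3] == 1.f);
    return 0;
}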
+
+struct bias_aoc_t {
+ bias_aoc_t(const rnn_conf_t &rnn, const float *data)
+ : bias_(data, rnn.n_bias, rnn.dic) {}
+ const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<const float, 2> bias_;
+};
+
+template <typename T>
+struct ws_states_aoc {
+ ws_states_aoc(const rnn_conf_t &rnn, T *data)
+ : state_(data, rnn.states_nld, rnn.states_ws_ld) {}
+ T &operator()(int batch, int dic) { return state_(batch, dic); }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<T, 2> state_;
+};
+using ws_states_aoc_t = ws_states_aoc<float>;
+using ws_states_aoc_u8_t = ws_states_aoc<uint8_t>;
+
+struct ws_diff_states_aoc_t {
+ ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data)
+ : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld,
+ rnn.states_ws_ld) {}
+ float &operator()(int state_n, int batch, int dic) {
+ return diff_states_(state_n, 0, batch, dic);
+ }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<float, 4> diff_states_;
+};
+
+struct ws_diff_w_iter_aoc_t {
+ ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data)
+ : diff_weights_iter_(
+ data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld)
+ , DIC_(rnn.dic) {}
+ float &operator()(int sic, int gate, int dic) {
+ return diff_weights_iter_(sic, gate * DIC_ + dic);
+ }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<float, 2> diff_weights_iter_;
+ int DIC_;
+};
+}
+}
+}
+}
+#endif