diff options
author | RĂ©mi Verschelde <rverschelde@gmail.com> | 2020-05-11 13:45:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-11 13:45:48 +0200 |
commit | 32133a11b56761df99579ad96ee29a47d2aed6b4 (patch) | |
tree | ab68992cfe6b1f59a618f713545fdcb3b6488b07 /thirdparty/oidn/mkl-dnn/src/cpu/rnn | |
parent | bbdfc7353c3af72fcdf037ff10b8571aa2afc230 (diff) | |
parent | 1bea8e1eacc68bcedbd3f207395bccf11011dae2 (diff) |
Merge pull request #38386 from reduz/new-lightmapper
New GPU lightmapper
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/rnn')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp | 90 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp | 180 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp | 170 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp | 143 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp | 113 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp | 191 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp | 401 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp | 788 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp | 328 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp | 380 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp | 426 | ||||
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp | 225 |
12 files changed, 3435 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp new file mode 100644 index 0000000000..537084db91 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Common for RNN and LSTM cell execution + */ +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +using namespace rnn_utils; + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_cell_execution_sig( + (_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execution)) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + if (rnn_postgemm_ != nullptr) + rnn_postgemm_->execute<src_data_t, acc_data_t>(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + else + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} +template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) { + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + /// bwd by data on the cell + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, + 0.0, diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + + /// bwd by weights on the cell + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + + if (!rnn.merge_gemm_iter) + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, + diff_w_iter_, rnn.diff_weights_iter_ld); + + /// bwd by bias we just accumulate diffs from the gates + gates_reduction(rnn, ws_gates_, diff_bias_); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp new file mode 100644 index 0000000000..e1a61d4c62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +#define AOC array_offset_calculator +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_[0]); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + // 1. gemm Wx[0-2],x + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + + // 2. gemm Wh[0-1],h + (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb, + rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + // 3. activation zt and rt + elemwise multiplication rt,ht-1 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j); + } + }); + + // 4. gemm Wh[2],h~t + (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1], + rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0, + &(ws_gates(0, 2, 0)), rnn.gates_ws_ld); + + // 5. activation h~t + calculate ht + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) { + assert(!"GRU int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + // use state memory for intermediate computations + // TODO: use cell ws for that + float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0)); + float *hG1_ = dhG1_; + AOC<float, 2> dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld); + AOC<float, 2> hG1(hG1_, rnn.states_nld, rnn.states_ws_ld); + + // 1. calculate dG2, dG1, and part of dht-1 + // dG2^ = dh * (1 - G0) * (1 - G2^2) + // dG0^ = dh * (ht-1 - G2) * u * (1 - G0) + // dht-1 (part) = dh * G0 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt + * one_m_square(ws_gates(i, 2, j)); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 0, j) = dG0; + ws_gates(i, 2, j) = dG2; + } + }); + + // 2. calculate intermediate d(hG1) + // d(hG1) = dG2 * W2h^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1], + rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0, + dhG1_, rnn.states_ws_ld); + + // 3. calculate dG1^ and part of dht-1 + // dG1^ = d(hG1) * h * G1 * (1 - G1) + // dht-1 (part) += d(hG1) * G1 + // h * G1 (required for dWh) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float G1 = ws_gates(i, 1, j); + diff_states_t_l(0, i, j) += dhG1(i, j) * G1; + ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1); + hG1(i, j) = G1 * h; + } + }); + + // 4. calculate diff weights + // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h) + gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_iter_ld); + gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)), + rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0, + &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld); + + // 5. calculate diff states + // dht-1 += dG1 * W1h + dG0 * W0h + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, + (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0], + rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0, + diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + // dWx += [dG0 dG1 dG2] * [x] + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld); + } + + // 6. calculate diff bias + gates_reduction(rnn, ws_gates_, diff_bias_); +} +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp new file mode 100644 index 0000000000..8dea8c90a4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU with linear before reset + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_); + AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j); + ws_gates(i, 0, j) = logistic_fwd( + ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd( + ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd( + ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + if (rnn.is_training) + ws_Wh_b(i, j) = Wh_b; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + // 1. calculate dG1 dG2 dG3 + // dG0 = (dht - G2) * dht * (1 - G0) * G0 + // dG1 = (W*h + b) * dG2 * (1 - G1) * G1 + // dG2 = (1 - G0) * dht * (1 - G2*G2) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + float dG2 = (1.0f - ws_gates(i, 0, j)) + * one_m_square(ws_gates(i, 2, j)) * dHt; + float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 2, j) = dG2; + ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j); + ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0; + ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1; + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) { + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + if (!rnn.merge_gemm_layer) { + // dx = dG * Wx^t + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + // dWx += dG^t * x + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + // dh += dGr * Wh^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld, + 1.0, diff_states_t_l_, rnn.states_ws_ld); + + // dWh += dGr^t * h + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_layer_ld); + + // db1-3 += e * dG + // db4 += e * (r * dG2) + gates_reduction(rnn, ws_gates_, diff_bias_); + + parallel_nd(rnn.dic, [&](int j) { + for (int i = 0; i < rnn.mb; i++) { + diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j); + } + }); +} + +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp new file mode 100644 index 0000000000..a15ba00d4c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "../simple_q10n.hpp" +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j)); + + float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j) + + ws_gates(i, 0, j) * ws_gates(i, 2, j); + states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) { + ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_u8_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + auto q_d = [&](float f) { + float qf = f * data_scale + data_shift; + return qz_a1b0<float, src_data_t>()(qf); + }; + + auto deq_w = [&](acc_data_t s, int gate, int j) { + return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ? + saturate<float>(s) * (1.f / (weights_scales[0] * data_scale)) : + saturate<float>(s) * (1.f / (weights_scales[gate * rnn.dic + j] + * data_scale)); + }; + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float G0 = logistic_fwd<float>( + deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j)); + float G1 = logistic_fwd<float>( + deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j)); + float G2 = tanh_fwd<float>( + deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j)); + float G3 = logistic_fwd<float>( + deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j)); + float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2; + states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp)); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Ct = c_states_t_l(i, j); + /// @todo save it in the workspace in fwd pass or recompute it to + /// save bw + float tanhCt = tanh_fwd(Ct); + // we have 2 incoming diffs on Ht + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dCt = diff_states_tp1_l(1, i, j) + + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt; + + float dG1 = c_states_tm1_l(i, j) * dCt + * x_m_square(ws_gates(i, 1, j)); + float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j)); + float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j)); + float dG2 + = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j)); + + diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j); + + ws_gates(i, 0, j) = dG0; + ws_gates(i, 1, j) = dG1; + ws_gates(i, 2, j) = dG2; + ws_gates(i, 3, j) = dG3; + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp new file mode 100644 index 0000000000..4536e8dfad --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution of Vanilla RNN + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +float activation<alg_kind::eltwise_relu, prop_kind::forward>( + float dd, float s, float alpha, float cliping) { + return relu_fwd<float>(s, alpha); +} + +template <> +float activation<alg_kind::eltwise_relu, prop_kind::backward>( + float dd, float s, float alpha, float cliping) { + return relu_bwd<float>(dd, s, alpha); +} + +template <> +float activation<alg_kind::eltwise_tanh, prop_kind::forward>( + float dd, float s, float alpha, float cliping) { + return tanh_fwd<float>(s); +} + +template <> +float activation<alg_kind::eltwise_tanh, prop_kind::backward>( + float dd, float s, float alpha, float cliping) { + return dd * one_m_square<float>(s); +} + +template <> +float activation<alg_kind::eltwise_logistic, prop_kind::forward>( + float dd, float s, float alpha, float cliping) { + return logistic_fwd<float>(s); +} + +template <> +float activation<alg_kind::eltwise_logistic, prop_kind::backward>( + float dd, float s, float alpha, float cliping) { + return dd * x_m_square<float>(s); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; j++) { + const float h + = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0); + ws_gates(i, 0, j) = states_t_l(i, j) = h; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) { + assert(!"VANILLA RNN int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; ++j) { + const float dH = diff_states_t_lp1(rnn.n_states, i, j) + + diff_states_tp1_l(0, i, j); + auto g = ws_gates(i, 0, j); + ws_gates(i, 0, j) = activation_func(dH, g, 0, 0); + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp new file mode 100644 index 0000000000..b39427caf9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp @@ -0,0 +1,191 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_RNN_PD_HPP +#define CPU_RNN_PD_HPP + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "rnn_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "rnn_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t { + using rnn_fwd_pd_t::rnn_fwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace data_type; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + /* Int8 is supported only for packed weights */ + data_type_t weights_iter_dt = weights_iter_md_.data_type; + data_type_t weights_layer_dt = weights_layer_md_.data_type; + ok = ok && IMPLICATION( + weights_iter_dt == s8, weights_iter_md_.format_kind + == format_kind::rnn_packed); + ok = ok && IMPLICATION( + weights_layer_dt == s8, weights_layer_md_.format_kind + == format_kind::rnn_packed); + + return ok ? status::success : status::unimplemented; + } +}; + +struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t { + using rnn_bwd_pd_t::rnn_bwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + if (diff_src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc)); + if (diff_weights_layer_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo)); + } + if (diff_weights_iter_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo)); + } + if (diff_dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc)); + if (with_bias() && diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo)); + if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + ok = ok && is_blocked(diff_src_layer_md_, 3) + && is_blocked(diff_dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_), + is_blocked(diff_src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&diff_dst_iter_md_), + is_blocked(diff_dst_iter_md_, 5)); + + ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_) + && rnn_utils::is_ldigo(&diff_weights_iter_md_); + ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_), + memory_desc_matches_tag(diff_bias_md_, ldgo)); + + return ok ? status::success : status::unimplemented; + } +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp new file mode 100644 index 0000000000..09445648aa --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp @@ -0,0 +1,401 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "rnn_utils.hpp" +#include "../jit_generator.hpp" +#include "../jit_uni_eltwise.hpp" +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "mkldnn_thread.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_uni_rnn_postgemm_kernel : public jit_generator { + + typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_, + void *c_states_t_l_, void *c_states_tm1_l_); + + jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){} + + virtual void init() = 0; + +template <typename src_data_t, typename acc_data_t> + rnn_elemwise_sig(execute) { + rnn_utils::ws_gates_aoc<acc_data_t> ws_gates(rnn, ws_gates_); + rnn_utils::bias_aoc_t bias(rnn, bias_); + rnn_utils::ws_states_aoc<src_data_t> states_t_l(rnn, states_t_l_); + rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + // Todo: add parallelization on dic for the batch 1 case + // Assumption: the kernel runs a loop on dic elements + parallel_nd(rnn.mb, [&](int i) { + auto b_ = &bias(0, 0); + auto g_ = &ws_gates(i, 0, 0); + auto s_tl_ = &states_t_l(i, 0); + auto c_tl_ = &c_states_t_l(i, 0); + auto c_tm1l_ = &c_states_tm1_l(i, 0); + kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_); + }); + } + +protected: + kernel_t kernel_; + const rnn_utils::rnn_conf_t &rnn_; + const primitive_attr_t *attr_; +}; + +template <cpu_isa_t isa, impl::data_type_t src_data_t> +struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd) + + typedef typename utils::conditional<src_data_t == data_type::u8, int32_t, + float>::type acc_data_t; + typedef typename utils::conditional<isa == avx512_core, + jit_uni_eltwise_injector_f32<avx512_common>, + jit_uni_eltwise_injector_f32<isa>>::type injector_t; + + jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr) + : jit_uni_rnn_postgemm_kernel(rnn, attr){} + + void init() override { + // we use rax for both constant tables as they use the same table + sigmoid_injector_ = new injector_t(this, + alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax); + tanh_injector_ = new injector_t(this, + alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax); + generate(); + kernel_ = (kernel_t) this->getCode(); + } + +protected: + injector_t *sigmoid_injector_; + injector_t *tanh_injector_; + + // register size in bytes + using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm; + size_t vlen = cpu_isa_traits<isa>::vlen; + size_t vlen_dst = (src_data_t == data_type::u8) ? vlen/4 : vlen; + size_t cstate_dt_size = sizeof(float); + size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float); + size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float); + size_t qscale_dt_size = sizeof(float); + size_t bias_dt_size = sizeof(float); + + void generate() { + using namespace Xbyak; + + int mask = attr_->rnn_weights_qparams_.mask_; + float *weights_scales = attr_->rnn_weights_qparams_.scales_; + float data_scale = attr_->rnn_data_qparams_.scale_; + float data_shift = attr_->rnn_data_qparams_.shift_; + + // Labels declaration + Label vector_loop_start_label, vector_loop_end_label; + Label rem_loop_start_label, rem_loop_end_label; + Label table_label; + + // Register map + Reg64 loop_cnt(r11); // loop counter + Reg64 table_reg(rbx); // table is used for data scale and shifts + Reg64 weights_scales_reg(r13); + // We skip vmm0 as it can be used by the injector for masks on sse4.2 + Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7); + + // constant table map + Address dscale_off_addr = ptr[table_reg]; + Address dshift_off_addr = ptr[table_reg + vlen]; + Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen]; + Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits<avx>::vlen]; + + // quantize from float to u8 + auto q_d = [&](Vmm f, Vmm tmp_vmm) { + uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm); + uni_vmulps(f, f, dscale_off_addr); // apply scale + uni_vaddps(f, f, dshift_off_addr); // apply shift + uni_vcvtps2dq(f, f); // convert to int32 + uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16 + uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation + // Note that the results are interleaved by 128 bit chunks, so we need to merge them together + switch (vlen) { + case 64: { //avx512 + Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx()); + uni_vmovups(tmpz, zmm_perm_mask_addr); + vpermd(fz, tmpz, fz); + break; } + case 32: { //avx + Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx()); + uni_vmovups(tmpy, ymm_perm_mask_addr); + vpermd(fy, tmpy, fy); + break; } + case 16: // sse: nothing to do + break; + default: assert(!"Unsupported case"); + }; + }; + + auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) { + if (packed) + uni_vrcpps(tmp, s); + else + uni_vrcpss(tmp, s); // prevent divide by zero + // we add one Newton iteration + uni_vmulps(s, s, tmp); + uni_vmulps(s, s, tmp); // s <- s * tmp^2 + uni_vaddps(tmp, tmp, tmp); + uni_vsubps(tmp, tmp, s); + uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2 + }; + + // dequantize from s32 to float + auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) { + // TODO: if mask is 0 precompute mul and inverse + if (mask == 0) + uni_vbroadcastss(tmp1, ptr[weights_scales_reg]); + else + uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]); + uni_vcvtdq2ps(s, s); + uni_vmulps(tmp1, tmp1, dscale_off_addr); + fast_recip(tmp1, tmp2, packed); + uni_vmulps(s, s, tmp1); + }; + + // We start code generations here + preamble(); + + // extract addresses passed as parameter +#ifdef _WIN32 + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = r10; + // Here we cannot use rbp to have initial stack pointer so we + // use rsp and offset it with the size of pushed registers in + // preamble + mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]); +#else + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = abi_param5; +#endif + + // initialize registers with addresses and constants + mov(table_reg, table_label); + mov(weights_scales_reg, size_t(weights_scales)); + // both sigmoid and tanh use the same table so load address just once in rax + sigmoid_injector_->load_table_addr(); + + mov(loop_cnt, rnn_.dic * gate_dt_size); + cmp(loop_cnt, vlen); + jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + + L(vector_loop_start_label); + { + // load G0 G1 G2 G3 + uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true); + } + + // add biases + uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0.getIdx()); + sigmoid_injector_->compute_vector(G1.getIdx()); + tanh_injector_->compute_vector(G2.getIdx()); + sigmoid_injector_->compute_vector(G3.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0 * G2 + uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1_vmm, tmp1_vmm, G1); + uni_vfmadd231ps(tmp1_vmm, G0, G2); + uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1_vmm.getIdx()); + uni_vmulps(tmp1_vmm, tmp1_vmm, G3); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, vlen); + add(addr_bias_reg, vlen); + add(addr_states_t_l_reg, vlen_dst); + add(addr_c_states_tm1_l_reg, vlen); + add(addr_c_states_t_l_reg, vlen); + if (mask != 0) + add(weights_scales_reg, vlen); + + // increment loop counter + sub(loop_cnt, vlen); + cmp(loop_cnt, vlen); + jge(vector_loop_start_label); + } + L(vector_loop_end_label); + + cmp(loop_cnt, 0); + je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + // Same code as above, we just use movuss for accessing inputs + // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar + L(rem_loop_start_label); + { + // remaping registers to Xmms + Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx()); + Xmm tmp1s_vmm(tmp1_vmm.getIdx()); + + // load G0 G1 G2 G3 + uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false); + } + + // add biases + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G0s, G0s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1s, G1s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2s, G2s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3s, G3s, tmp1s_vmm); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0s.getIdx()); + sigmoid_injector_->compute_vector(G1s.getIdx()); + tanh_injector_->compute_vector(G2s.getIdx()); + sigmoid_injector_->compute_vector(G3s.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0s * G2 + uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s); + uni_vfmadd231ps(tmp1s_vmm, G0s, G2s); + uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1s_vmm.getIdx()); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, gate_dt_size); + add(addr_bias_reg, bias_dt_size); + add(addr_states_t_l_reg, hstate_dt_size); + add(addr_c_states_tm1_l_reg, cstate_dt_size); + add(addr_c_states_t_l_reg, cstate_dt_size); + if (mask != 0) + add(weights_scales_reg, qscale_dt_size); + + // increment loop counter + sub(loop_cnt, gate_dt_size); + cmp(loop_cnt, 0); + jg(rem_loop_start_label); + + } + L(rem_loop_end_label); + + postamble(); + + // Again, only one table is needed and shared between sigmoid and tanh + sigmoid_injector_->prepare_table(false); + tanh_injector_->prepare_table(true); + + L(table_label); + { + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale)); + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift)); + // perm mask for ymm + dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7); + // perm mask for zmm + dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7); + dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14); + } + } + +}; + +template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::f32>; +template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::f32>; +template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::f32>; + +template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::u8>; +template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::u8>; +template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::u8>; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp new file mode 100644 index 0000000000..ead536816c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp @@ -0,0 +1,788 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + General architecture + + for diff states, we have n_states + 1 as we have n_states diff + to propagate to the previous iteration and 1 states to propagate + to the previous layer + index 0 is dh for cell(t-1, l) to consume + index 1 is dc for cell(t-1, l) to consume + index 2 is dh for cell(t, l-1) to consume + this indexing enables to have the same indexing for states in elemwise + function + only the cell execution function should be impacted + + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "../gemm/gemm.hpp" +#include "../simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::memory_tracking::names; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction( + const rnn_conf_t &rnn, const acc_data_t *ws_gates_, + float *diff_bias_) const { + auto body = [&](int i, int k) { + for (int j = 0; j < rnn.mb; j++) + diff_bias_[i * rnn.dic + k] + += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k]; + }; + + // @todo block k on simd-width +#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \ + /* icc 17.0 has a problem with simd collapse */ \ + && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700)) +#pragma omp parallel for simd collapse(2) + for (int i = 0; i < rnn.n_gates; i++) + for (int k = 0; k < rnn.dic; k++) + body(i, k); +#else + parallel_nd(rnn.n_gates, rnn.dic, body); +#endif +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) { + assert(ldA * ldB * ldC != 0); + extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB, + &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm); +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) { + assert(!"non packed gemm is disabled for int8"); +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + assert(transA == 'N'); + cblas_sgemm_compute(CblasColMajor, CblasPacked, + (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_, + ldB, beta, c_, ldC); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + int8_t offseta = 0, offsetb = 0; + int32_t offsetc = 0; + cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked, + CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_, + ldB, offsetb, beta, c_, ldC, &offsetc); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +//*************** Grid computations strategy: linear ***************// +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_grid_execution_sig( + (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) { + AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, + rnn.states_nld * rnn.states_ws_ld); + AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter, + rnn.gates_nld * rnn.gates_ws_ld); + AOC<weights_data_t *, 3> weights_input( + weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer); + AOC<weights_data_t *, 3> weights_states( + weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter); + AOC<float*, 3> bias( + bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer, + rnn.n_dir, + rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld); + AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir, + rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld); + AOC<float, 3> diff_bias( + diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + AOC<float, 4> ws_grid( + ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell); + + // We run the grid of computation + for (int dir = 0; dir < rnn.n_dir; dir++) { + for (int j = 0; j < rnn.n_layer; j++) { + int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1; + + if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, + rnn.mb * rnn.n_iter, rnn.slc, 1.0, + weights_input(lay, dir, 0), rnn.weights_iter_ld, + &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0, + &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld); + } + + for (int i = 0; i < rnn.n_iter; i++) { + int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1; + (this->*cell_func)(rnn, + &(ws_states(lay + 1, dir, iter + 1, 0)), + &(ws_c_states(lay + 1, dir, iter + 1, 0)), + &(ws_diff_states(lay, dir, 0, iter, 0)), + &(weights_input(lay, dir, 0)), + &(weights_states(lay, dir, 0)), + &(bias(lay, dir, 0)), + &(ws_states(lay, dir, iter + 1, 0)), + &(ws_states(lay + 1, dir, iter, 0)), + &(ws_c_states(lay + 1, dir, iter, 0)), + &(ws_diff_states(lay + 1, dir, 0, iter, 0)), + &(ws_diff_states(lay, dir, 0, iter + 1, 0)), + &(diff_weights_layer(lay, dir, 0)), + &(diff_weights_iter(lay, dir, 0)), + &(diff_bias(lay, dir, 0)), + &(ws_gates(lay, dir, iter, 0)), + &(ws_grid(lay, dir, iter, 0)), + ws_cell_); + } + + if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter, + rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0), + rnn.weights_layer_ld, + (src_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, 0.0, + (acc_data_t *)(&(ws_diff_states( + lay, dir, rnn.n_states, 0, 0))), + rnn.states_ws_ld); + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay, dir, 1, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))), + rnn.diff_weights_layer_ld); + } + if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) { + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))), + rnn.diff_weights_iter_ld); + } + } + } +} + +//********* GRID computations strategy: utility functions **********// + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_diff_states_, const src_data_t *__restrict xt_, + const float *__restrict diff_dst_layer_) const { + + AOC<src_data_t, 4> ws_states( + ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto xt_d = memory_desc_wrapper(pd()->src_md(0)); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto xxt = xt_ + xt_d.blk_off(it, b); + src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0)); + src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0)); + if (rnn.exec_dir != r2l) + for (int c = 0; c < rnn.slc; c++) + ws_l2r_ptr[c] = xxt[c]; + if (rnn.exec_dir != l2r) + for (int c = 0; c < rnn.slc; c++) + ws_r2l_ptr[c] = xxt[c]; + }); +} + +template <> +void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_, + const float *diff_dst_layer_) const { + AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0)); + + switch (rnn.exec_dir) { + case bi_concat: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[rnn.dic + s]; + } + }); + break; + case bi_sum: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case l2r: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case r2l: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x = diff_dst_layer_ + + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + default: assert(!"Unsupported direction"); break; + } +} + +/* For int8 configuration, input iteration states may be of types f32 or u8 + * Internally h_state is always stored in u8 and c_state is always stored in f32 + * If input states are of type u8 then h state is copied and c state is dequantized + * If input states are of type f32 then h state is quantized and c_state is copied + * */ +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +template <typename input_data_t> +void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_c_states_, float *__restrict ws_diff_states_, + const input_data_t *__restrict firstit_states_, + const float *__restrict diff_dst_iter_) const { + AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](input_data_t f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0<float, src_data_t>()(qf); + } else + return (src_data_t)f; + }; + + const bool dequantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::u8; + auto maybe_deq = [&](input_data_t s) { + if (dequantize) + return (((float)s - data_shift) / data_scale); + else + return (float)s; + }; + auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1)); + if (firstit_states_) { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.sic; s++) + ws_states(lay + 1, dir, 0, b, s) = maybe_q( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 0, b, s)]); + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.sic; s++) + ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 1, b, s)]); + }); + } else { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int j = 0; j < rnn.sic; j++) { + ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0; + ws_c_states(lay + 1, dir, 0, b, j) = 0.0f; + } + }); + } +} + +template <> +template <typename input_data_t> +void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter_) const { + AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1)); + if (diff_dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + array_copy(&(ws_diff_states( + lay, dir, state, rnn.n_iter, b, 0)), + diff_dst_iter_ + + diff_dst_iter_d.blk_off( + lay, dir, state, b), + rnn.dic); + }); + } else { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int i) { + for (int j = 0; j < rnn.dic; j++) + ws_diff_states(lay, dir, state, rnn.n_iter, i, j) + = 0.0f; + }); + } +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +template <typename dst_data_t> +void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + + auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); + AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float shift = (pd()->attr()->rnn_data_qparams_.shift_); + float scale = (pd()->attr()->rnn_data_qparams_.scale_); + + const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (dst_data_t)(((float)s - shift) / scale); + else + return (dst_data_t)s; + }; + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + if (rnn.exec_dir != r2l) { + for (int s = 0; s < rnn.dic; s++) { + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s)); + } + dir = 1; + } + if (rnn.exec_dir != l2r) { + for (int s = 0; s < rnn.dic; s++) + switch (rnn.exec_dir) { + case bi_sum: + dst_layer_[dst_layer_d.blk_off(it, b, s)] + += maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + break; + default: + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + } + } + }); +} + +template <> +template <typename dst_data_t> +void ref_rnn_bwd_f32_t::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0)); + AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + for (int s = 0; s < rnn.slc; s++) { + float *dst_addr = diff_src_layer_ + + diff_src_layer_d.blk_off( + (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it, + b, dir * rnn.slc + s); + float res = ws_diff_states(0, 0, rnn.n_states, it, b, s); + if (rnn.n_dir - 1) + res += ws_diff_states( + 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s); + dst_addr[0] = res; + } + }); +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +template <typename output_data_t> +void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); + AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::u8 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](float f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0<float, output_data_t>()(qf); + } else + return (output_data_t)f; + }; + + const bool dequantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (output_data_t)(((float)s - data_shift) / data_scale); + else + return (output_data_t)s; + }; + if (dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, + [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)] + = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s)); + } + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)] + = maybe_q(ws_c_states( + lay + 1, dir, rnn.n_iter, b, s)); + } + }); + } +} + +template <> +template <typename output_data_t> +void ref_rnn_bwd_f32_t::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1)); + AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + if (diff_src_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + for (int s = 0; s < rnn.sic; s++) { + diff_src_iter_[diff_src_iter_d.blk_off( + lay, dir, state, b, s)] + = ws_diff_states(lay, dir, state, 0, b, s); + } + }); + } +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) { + /* Original set of bias provided by the user */ + AOC<const float, 5> b( + b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + /* Array of pointers initialized in packing */ + AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC<float, 3> scratch_bias( + scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + + if (rnn.copy_bias) { + parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic, + [&](size_t i) { scratch_bias_[i] = b_[i]; }); + } + + for (int i = 0; i < rnn.n_layer; i++) { + for (int d = 0; d < rnn.n_dir; d++) { + int offset_bias = 0; + for (int p = 0; p < rnn.n_parts_bias; p++) { + bias(i, d, p) = rnn.copy_bias + ? (float *) &scratch_bias(i, d, offset_bias) + : (float *) &b(i, d, offset_bias); + offset_bias += rnn.parts_bias[p] * rnn.dic; + } + } + } + +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_bias_finalize_sig( + (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) { + if (rnn.dt_conf != all_f32) { + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0; + for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++) + for (int j = 0; j < rnn.n_bias * rnn.dic; j++) { + size_t off = i * rnn.n_bias * rnn.dic + j; + float weights_scale + = scale_per_oc ? weights_scales[j] : weights_scales[0]; + scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off]) + * data_shift / (weights_scale * data_scale); + } + } +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type, + weights_type>::assign_packed_weights)) { + assert(md->format_kind == format_kind::rnn_packed); + const auto packed_desc = md->format_desc.rnn_packed_desc; + AOC<weights_data_t *, 3> weights(weights_, + rnn.n_layer, rnn.n_dir, packed_desc.n_parts); + + size_t offset_packed = 0; + for (int l = 0; l < rnn.n_layer; l++) + for (int d = 0; d < rnn.n_dir; d++) { + for (int p = 0; p < packed_desc.n_parts; p++) { + weights(l, d, p) = (weights_data_t *)&w_[offset_packed]; + offset_packed + += packed_desc.part_pack_size[p] / sizeof(weights_data_t); + } + } +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_weights_assign_sig( + (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) { + assert(md->format_kind == format_kind::blocked); + const auto &blk = md->format_desc.blocking; + /* Original set of weights provided by the user */ + AOC<const weights_data_t, 3> w(w_, + rnn.n_layer, rnn.n_dir, (int)blk.strides[1]); + /* Array of pointers for each part of weights */ + AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts); + + for (int i = 0; i < rnn.n_layer; i++) + for (int d = 0; d < rnn.n_dir; d++) { + size_t offset_weights = 0; + for (int p = 0; p < n_parts; p++) { + weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights); + offset_weights += gates_per_part[p] * blk.strides[3]; + } + } +} + +//********************* Execution function *********************// +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_( + const exec_ctx_t &ctx) const { + const rnn_conf_t &rnn = this->pd()->rnn_; + auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER); + auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER); + auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER); + auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + + auto dst_last_layer = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER) + : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER)); + auto dst_last_iter = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER) + : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER)); + + auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER); + auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER); + + auto w_layer = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp); + auto w_iter = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp); + auto w_iter_comp = reinterpret_cast<const float *>( + iter_weights_n_comp + rnn.weights_iter_comp_offset); + auto w_layer_comp = reinterpret_cast<const float *>( + layer_weights_n_comp + rnn.weights_layer_comp_offset); + + auto scratchpad = this->scratchpad(ctx); + + auto ptr_wei_layer + = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer); + auto ptr_wei_iter + = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter); + auto ptr_bias = + scratchpad.template get<float *>(key_rnn_ptrs_bia); + + // fetchihg buffers from the workspace + // if no workspace was provided we use the scratchpad + char *scratch_ptr = scratchpad.template get<char>(key_rnn_space); + char *ws_ptr = nullptr; + if (rnn.use_workspace) + ws_ptr = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE) + : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE)); + + char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr; + acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_); + src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_); + float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_); + float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_); + float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_); + float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_); + + auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER); + auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER); + + auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER); + auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + // Fetching extra buffers from scratchpad + float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_); + + // initialize diff_states to 0 + if (aprop == prop_kind::backward) + array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float)); + + /* Pack(if using packed gemm API) or copy(if input arrays have bad leading + * dimension */ + (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias); + + (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1), + rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic, + rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter, + rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter, + ptr_bias, bias, ws_bias); + (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0), + rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc, + rnn.n_parts_weights_layer, rnn.parts_weights_layer, + rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias, + bias, ws_bias); + + (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp); + + // we first need to copy the initial states and input into ws + copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer); + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const float *)states, diff_dst_iter); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const uint8_t *)states, diff_dst_iter); + else + assert(!"unimplemented"); + + // run the execution on the grid + (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias, + ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid, + diff_weights_layer, diff_weights_iter, diff_bias); + + // Finally we copy the results to the result buffers + if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states, + ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8) + copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer, + ws_states, ws_diff_states); + else + assert(!"unimplemented"); + + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else + assert(!"unimplemented"); +}; + +/* Fix for MSVS warning C4661 */ +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise); + +template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>; +template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>; +template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>; + +#undef AOC +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp new file mode 100644 index 0000000000..6f449a9016 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp @@ -0,0 +1,328 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_RNN_HPP +#define CPU_REF_RNN_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "../cpu_isa_traits.hpp" +#include "../gemm/os_blas.hpp" + +#include "cpu_rnn_pd.hpp" +#include "../cpu_primitive.hpp" +#include "rnn_utils.hpp" +#include "jit_uni_rnn_postgemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <alg_kind_t alg_kind, prop_kind_t prop_kind> +float activation(float s, float alpha, float cliping, float dd); + +template <prop_kind_t aprop, impl::data_type_t src_type, + impl::data_type_t weights_type> +struct _ref_rnn_common_t : public cpu_primitive_t { + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<weights_type>::type weights_data_t; + typedef typename utils::conditional<src_type == data_type::u8, int32_t, + float>::type acc_data_t; + + using class_name = _ref_rnn_common_t<aprop, src_type, weights_type>; + + typedef rnn_elemwise_sig((class_name::*elemwise_f)); + typedef rnn_cell_execution_sig((class_name::*cell_execution_f)); + typedef rnn_grid_execution_sig((class_name::*grid_execution_f)); + + typedef rnn_gemm_sig((class_name::*gemm_t)); + typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t)); + typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t)); + typedef rnn_weights_assign_sig((class_name::*weights_assign_t)); + + using base_pd_t = + typename utils::conditional<false || aprop == prop_kind::forward, + cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type; + + struct pd_t : public base_pd_t { + using base_pd_t::base_pd_t; + + DECLARE_COMMON_PD_T("ref:any", class_name); + + status_t init() { + using namespace prop_kind; + using namespace utils; + using namespace format_tag; + using namespace rnn_utils; + const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind; + + data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type; + data_type_t weights_iter_dt + = this->desc()->weights_iter_desc.data_type; + data_type_t weights_layer_dt + = this->desc()->weights_layer_desc.data_type; + + bool ok = true + && one_of(cell_kind, alg_kind::vanilla_rnn, + alg_kind::vanilla_lstm, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset) + && IMPLICATION(aprop == prop_kind::forward, + one_of(this->desc()->prop_kind, forward_training, + forward_inference)) + && IMPLICATION(aprop == backward, + one_of(this->desc()->prop_kind, backward)) + && src_layer_dt == src_type + && everyone_is( + weights_type, weights_iter_dt, weights_layer_dt) + && this->set_default_params() == status::success + && this->with_bias(); + if (!ok) + return status::unimplemented; + + init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1), + this->weights_md(0), this->weights_md(1), this->dst_md(0)); + + if (rnn_.dt_conf == all_f32) + ok = ok && this->attr()->has_default_values(); + + // Set weights descriptors to desired format + memory_desc_t new_weights_layer_md = *this->weights_md(0); + CHECK(set_expected_desc(rnn_, new_weights_layer_md, false)); + if (this->weights_layer_md_.format_kind == format_kind::any) { + this->weights_layer_md_ = new_weights_layer_md; + } else if (this->weights_layer_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_layer_md_ != new_weights_layer_md) + return status::unimplemented; + } + + memory_desc_t new_weights_iter_md = *this->weights_md(1); + CHECK(set_expected_desc(rnn_, new_weights_iter_md, true)); + if (this->weights_iter_md_.format_kind == format_kind::any) { + this->weights_iter_md_ = new_weights_iter_md; + } else if (this->weights_iter_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_iter_md_ != new_weights_iter_md) + return status::unimplemented; + } + + CHECK(this->check_layout_consistency()); + + set_conf(rnn_, *this->desc(), this->weights_md(0), + this->weights_md(1), this->diff_weights_md(0), + this->diff_weights_md(1)); + + size_t scratchpad_sz{0}, ws_sz{0}; + get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz); + + // initialize the workspace if needed + if (rnn_.is_training) { + dims_t ws_dims = { (int)ws_sz }; + mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims, + data_type::u8, format_tag::x); + } + + init_scratchpad(scratchpad_sz); + + return status::success; + } + + rnn_utils::rnn_conf_t rnn_; + + private: + void init_scratchpad(size_t scratchpad_sz) { + using namespace memory_tracking::names; + auto scratchpad = this->scratchpad_registry().registrar(); + scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096); + + int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1; + int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts; + scratchpad.book(key_rnn_ptrs_wei_layer, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_wei_iter, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_bia, + sizeof(float *) * ptr_wei_sz); + } + }; + + _ref_rnn_common_t(const pd_t *apd) + : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) { + /// @todo set max_feature_size assuming that we limit the number of + /// iterations and layer to one if slc != dic and sic != dic + /// respectively + + bias_preparation_func = &class_name::bias_prepare; + bias_finalization_func = &class_name::bias_finalize; + + auto set_gemm_funcs + = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) { + if (packed_gemm) { + g = &class_name::packed_gemm; + a = &class_name::assign_packed_weights; + } else { + g = &class_name::gemm; + a = &class_name::assign_weights; + } + }; + set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func, + weights_iter_assign_func); + + set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func, + weights_layer_assign_func); + + switch (pd()->cell_kind()) { + case alg_kind::vanilla_lstm: + cell_func = &class_name::cell_execution; + if (aprop == prop_kind::forward) { + if (mayiuse(avx512_core)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx512_core, src_type>( + pd()->rnn_, pd()->attr()); + else if (mayiuse(avx2)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx2, src_type>( + pd()->rnn_, pd()->attr()); + else if (mayiuse(sse42)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<sse42, src_type>( + pd()->rnn_, pd()->attr()); + assert(rnn_postgemm_ != nullptr); + rnn_postgemm_->init(); + } + elemwise_func = &class_name::lstm_elemwise; + break; + case alg_kind::vanilla_rnn: // @todo switch on cell kind + cell_func = &class_name::cell_execution; + elemwise_func = &class_name::rnn_elemwise; + switch (pd()->activation_kind()) { + case alg_kind::eltwise_relu: + activation_func = &activation<alg_kind::eltwise_relu, aprop>; + break; + case alg_kind::eltwise_tanh: + activation_func = &activation<alg_kind::eltwise_tanh, aprop>; + break; + case alg_kind::eltwise_logistic: + activation_func = &activation<alg_kind::eltwise_logistic, aprop>; + break; + default: break; + } + break; + case alg_kind::vanilla_gru: + cell_func = &class_name::cell_execution_gru; + break; + case alg_kind::gru_linear_before_reset: + cell_func = &class_name::cell_execution_gru_lbr; + elemwise_func = &class_name::gru_lbr_elemwise; + break; + default: break; + } + + grid_computation = &class_name::linear_execution; + + size_t scratchpad_size, workspace_size; + rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_, + ws_c_states_offset_, ws_diff_states_offset_, + ws_grid_comp_offset_, ws_cell_comp_offset_, + ws_bias_offset_, scratchpad_size, workspace_size); + } + + ~_ref_rnn_common_t() {} + + // typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_(ctx); + return status::success; + } + +private: + void execute_(const exec_ctx_t &ctx) const; + rnn_grid_execution_sig(linear_execution); + rnn_cell_execution_sig(cell_execution); + rnn_cell_execution_sig(cell_execution_gru); + rnn_cell_execution_sig(cell_execution_gru_lbr); + rnn_elemwise_sig(rnn_elemwise); + rnn_elemwise_sig(lstm_elemwise); + rnn_elemwise_sig(gru_lbr_elemwise); + rnn_gemm_sig(gemm); + rnn_gemm_sig(packed_gemm); + rnn_bias_prepare_sig(bias_prepare); + rnn_bias_finalize_sig(bias_finalize); + rnn_weights_assign_sig(assign_weights); + rnn_weights_assign_sig(assign_packed_weights); + + float (*activation_func)(float dd, float s, float alpha, float cliping); + + void copy_init_layer(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, + const src_data_t *xt_, const float *diff_dst_layer) const; + + template <typename input_data_t> + void copy_init_iter(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter) const; + + template <typename dst_data_t> + void copy_res_layer(const rnn_utils::rnn_conf_t &rnn, + dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const; + + template <typename output_data_t> + void copy_res_iter(const rnn_utils::rnn_conf_t &rnn, + output_data_t *dst_iter_, float *diff_src_iter, + const src_data_t *ws_states_, float *ws_c_states, + const float *ws_diff_states_) const; + + void gates_reduction(const rnn_utils::rnn_conf_t &rnn, + const acc_data_t *ws_gates_, float *diff_bias_) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + size_t ws_gates_offset_; + size_t ws_states_offset_; + size_t ws_c_states_offset_; + size_t ws_bias_offset_; + size_t ws_diff_states_offset_; + size_t ws_grid_comp_offset_; + size_t ws_cell_comp_offset_; + jit_uni_rnn_postgemm_kernel *rnn_postgemm_; + + grid_execution_f grid_computation; + cell_execution_f cell_func; + + bias_prepare_t bias_preparation_func; + bias_finalize_t bias_finalization_func; + weights_assign_t weights_layer_assign_func; + weights_assign_t weights_iter_assign_func; + + gemm_t gemm_layer_func; + gemm_t gemm_iter_func; + elemwise_f elemwise_func; +}; + +using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>; +using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>; +using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>; +} +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp new file mode 100644 index 0000000000..597c63e3f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp @@ -0,0 +1,380 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_RNN_REORDERS_HPP +#define CPU_RNN_REORDERS_HPP + +#include <assert.h> + +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "simple_q10n.hpp" +#include "cpu_reorder_pd.hpp" +#include "../gemm/os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t type_i, data_type_t type_o> +struct rnn_data_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc) + && od == id; + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + }; + +private: + typedef typename prec_traits<type_i>::type in_data_t; + typedef typename prec_traits<type_o>::type out_data_t; + + rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const size_t nelems = input_d.nelems(); + const float scale = pd()->attr()->rnn_data_qparams_.scale_; + const float shift = pd()->attr()->rnn_data_qparams_.shift_; + + parallel_nd(nelems, [&](size_t i) { + float in = (float)input[input_d.off_l(i)] * scale + shift; + output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in); + }); + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <data_type_t type_i, data_type_t type_o> +struct rnn_weights_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && od.format_kind() == format_kind::rnn_packed + && od.rnn_packed_desc().format == mkldnn_ldigo_p + && od.rnn_packed_desc().n_parts == 1 + && attr != nullptr; + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + _pd->itag_ = itag; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + format_tag_t itag_; + + private: + void init_scratchpad() { + const memory_desc_wrapper id(src_md()); + const size_t nelems = id.nelems(); + const auto &dims = id.dims(); + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + size_t quantization_size = sizeof(int8_t) * nelems; + size_t reduction_size = itag_ == ldigo + ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0] + * dims[1] * dims[3] * dims[4] + : 0; + scratchpad.book( + key_reorder_rnn_weights_quantization, quantization_size); + scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size); + } + }; + +private: + typedef typename prec_traits<type_i>::type in_data_t; + typedef typename prec_traits<type_o>::type out_data_t; + + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + + /* Quantize input & compute compensation */ + auto quantized = (int8_t * __restrict)scratchpad(ctx).template get<void>( + memory_tracking::names::key_reorder_rnn_weights_quantization); + auto reduction = (int32_t * __restrict)scratchpad(ctx).template get<void>( + memory_tracking::names::key_reorder_rnn_weights_reduction); + float *comp = reinterpret_cast<float *>( + output + output_d.rnn_packed_desc().offset_compensation); + const float *scales = pd()->attr()->rnn_weights_qparams_.scales_; + const int mask = pd()->attr()->rnn_weights_qparams_.mask_; + + if (is_igo) { + int nthr = mkldnn_get_max_threads(); + int LD_nthr = nstl::min(L * D, nthr); + int I_nthr = nstl::min(I, nthr / LD_nthr); + parallel(nthr, [&](const int ithr, const int nthr) { + int LD_ithr = -1, LD_s = -1, LD_e = -1; + int I_ithr = -1, I_s = -1, I_e = -1; + if (ithr < LD_nthr * I_nthr) { + LD_ithr = ithr % LD_nthr; + I_ithr = ithr / LD_nthr; + balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e); + balance211(I, I_nthr, I_ithr, I_s, I_e); + } + int32_t *comp_ithr = reduction + I_ithr * L * D * G * O; + for (int ld = LD_s; ld < LD_e; ld++) { + for (int go = 0; go < G * O; go++) + comp_ithr[ld * G * O + go] = 0; + for (int i = I_s; i < I_e; i++) { + PRAGMA_OMP_SIMD() + for (int go = 0; go < G * O; go++) { + const float s = scales[(mask == 0) ? 0 : go]; + int8_t q = qz_b0<in_data_t, out_data_t>()( + input[ld * I * G * O + i * G * O + go], s); + quantized[ld * I * G * O + i * G * O + go] + = (int32_t)q; + comp_ithr[ld * G * O + go] += (int32_t)q; + } + } + } + }); + parallel_nd(L * D * G * O, + [&](int s) { comp[s] = saturate<float>(reduction[s]); }); + for (int i = 1; i < I_nthr; i++) { + parallel_nd(L * D * G * O, [&](int s) { + comp[s] += saturate<float>( + reduction[i * L * D * G * O + s]); + }); + } + } else { + parallel_nd(L * D, G * O, [&](int ld, int go) { + int32_t compensation = 0; + const float s = scales[(mask == 0) ? 0 : go]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < I; i++) { + int8_t q = qz_b0<in_data_t, out_data_t>()( + input[ld * G * O * I + go * I + i], s); + compensation += (int32_t)q; + quantized[ld * G * O * I + go * I + i] = q; + } + comp[ld * G * O + go] = saturate<float>(compensation); + }); + } + + /* Pack */ + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + int n_parts = output_d.rnn_packed_desc().n_parts; + const size_t *size_packed_cell + = output_d.rnn_packed_desc().part_pack_size; + const int *parts = output_d.rnn_packed_desc().parts; + const int n = output_d.rnn_packed_desc().n; + char *to_pack = output; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = parts[p] * O; + int k_p = I; + cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix, + is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p, + &quantized[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, g, 0, 0)], + is_igo ? G * O : I, to_pack); + to_pack += size_packed_cell[p]; + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <> +struct rnn_weights_reorder_t<data_type::f32, data_type::f32> + : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == data_type::f32 + && od.data_type() == data_type::f32 + && od.format_kind() == format_kind::rnn_packed + && utils::one_of(od.rnn_packed_desc().format, + mkldnn_ldigo_p, mkldnn_ldgoi_p) + && attr->has_default_values(); + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->itag_ = itag; + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + + format_tag_t itag_; + }; + +private: + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc(); + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + /* Pack */ + bool cross_case = false + || (pd()->itag_ == format_tag::ldigo + && rnn_pdata.format == mkldnn_ldgoi_p) + || (pd()->itag_ == format_tag::ldgoi + && rnn_pdata.format == mkldnn_ldigo_p); + auto trans = cross_case ? CblasTrans : CblasNoTrans; + int n_parts = rnn_pdata.n_parts; + const size_t *size_packed_cell = rnn_pdata.part_pack_size; + const int *parts = rnn_pdata.parts; + const int n = rnn_pdata.n; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = is_igo ? parts[p] * O : I; + int k_p = is_igo ? I : parts[p] * O; + int ld = is_igo ? G * O : I; + cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n, + k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, 0, g, 0)], + ld, output); + output += size_packed_cell[p] / sizeof(float); + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp new file mode 100644 index 0000000000..1d60415cbc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp @@ -0,0 +1,426 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "rnn_utils.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace rnn_utils; +using namespace format_tag; +using namespace rnn_packed_format; +using namespace data_type; + +bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1 + && str[3] == dims[4] && str[1] == str[2] * dims[2] + && str[0] == str[1] * dims[1]; +}; + +bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1 + && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3] + && str[0] == str[1] * dims[1]; +}; + +void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d) { + rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + rnn.is_training = utils::one_of( + rd.prop_kind, prop_kind::forward_training, prop_kind::backward); + rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset; + + switch (rd.direction) { + case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break; + case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break; + case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break; + case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break; + default: break; + } + + if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(), + weights_layer_d.data_type())) + rnn.dt_conf = all_f32; + else if (dst_layer_d.data_type() == u8) { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8u8; + else + rnn.dt_conf = f32u8f32u8; + } else { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8f32; + else + rnn.dt_conf = f32u8f32f32; + } + + rnn.n_layer = weights_layer_d.dims()[0]; + rnn.n_iter = src_layer_d.dims()[0]; + rnn.n_dir = weights_layer_d.dims()[1]; + rnn.n_gates = weights_layer_d.dims()[3]; + rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc); + rnn.n_bias = rnn.n_gates + rnn.is_lbr; + rnn.mb = src_layer_d.dims()[1]; + rnn.sic = weights_iter_d.dims()[2]; + rnn.slc = weights_layer_d.dims()[2]; + rnn.dic = weights_layer_d.dims()[4]; + rnn.dlc = dst_layer_d.dims()[2]; + + rnn.gates_ld = rnn.dic * rnn.n_gates; + rnn.gates_nld = rnn.mb; + rnn.states_nld = rnn.mb; + + /* Set the correct number of weights parts */ + bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru; + rnn.n_parts_weights_layer = 1; + rnn.parts_weights_layer[0] = rnn.n_gates; + rnn.parts_weights_layer[1] = 0; + + rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1; + rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates; + rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0; + + rnn.n_parts_bias = 1; + rnn.parts_bias[0] = rnn.n_bias; + rnn.parts_bias[1] = 0; + + /* Decide wich gemm implementation to use: packed/nonpacked jit/cblas + * and if to mergre gemm across iterations */ + bool is_int8 = rnn.dt_conf != all_f32; + rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) + || is_int8; + bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset); + rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8; + bool is_inference = !rnn.is_training; + + rnn.use_jit_gemm = !mayiuse(avx512_mic) + && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100)) + || (rnn.is_training && rnn.dic < 500)); + + /* Decide to copy bias */ + rnn.copy_bias = rnn.dt_conf != all_f32; + +#if USE_MKL_PACKED_GEMM + rnn.use_layer_packed_gemm + = (weights_layer_d.format_kind() == format_kind::any + && rnn.slc > 760 && rnn.dic > 760 && is_inference) + || is_int8; // packed gemm is the only supported option for int8 + rnn.use_iter_packed_gemm + = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760 + && rnn.dic > 760 && is_inference) + || is_int8; +#else + rnn.use_layer_packed_gemm = false; + rnn.use_iter_packed_gemm = false; +#endif + + /* Set packed gemm sizes */ + if (rnn.use_layer_packed_gemm) { + rnn.weights_layer_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_layer; p++) { + int m_p = rnn.is_fwd + ? (rnn.parts_weights_layer[p] * rnn.dic) + : rnn.slc; + int k_p = rnn.is_fwd + ? rnn.slc + : (rnn.parts_weights_layer[p] * rnn.dic); + int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_layer_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_layer_pack_size[p] = 0; +#endif + rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_layer_pack_size[p]; + } + rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size; + rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float); + } + + if (rnn.use_iter_packed_gemm) { + rnn.weights_iter_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_iter; p++) { + int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) : + rnn.sic; + int k_p = rnn.is_fwd ? rnn.sic : + (rnn.parts_weights_iter[p] * rnn.dic); + int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_iter_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_iter_pack_size[p] = 0; +#endif + rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_iter_pack_size[p]; + } + rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size; + rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float); + } + +} + +void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d) { + + /* Set leading dimensions for input weights arrays depending on input format + */ + rnn.weights_layer_is_packed + = weights_layer_d.format_kind() == format_kind::rnn_packed; + rnn.weights_iter_is_packed + = weights_iter_d.format_kind() == format_kind::rnn_packed; + + auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) { + ld = 0; nld = 0; + if (md.is_blocking_desc()) { + if (is_ldigo(md)) { + ld = (int)md.blocking_desc().strides[2]; + nld = md.dims()[2]; + } else if (is_ldgoi(md)) { + ld = (int)md.blocking_desc().strides[4]; + nld = md.dims()[3] * md.dims()[4]; + } else + assert(!"unsupported weights format"); + } + }; + set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld); + set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld); + if (!rnn.is_fwd) { + set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld, + rnn.diff_weights_layer_nld); + set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld, + rnn.diff_weights_iter_nld); + } + + int sizeof_states_dt + = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t); + rnn.states_ws_ld + = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)), + sizeof_states_dt); + rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float)); + + /* Set workspace sizes to store: + * states to copmute a pass + * diff states to copmute bwd pass (training only) + * intermediate results from the gates + */ + rnn.use_workspace = rnn.is_training; + rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir + * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt; + bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm; + rnn.ws_c_states_size = is_lstm + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb + * rnn.states_ws_ld * sizeof(float) + : 0; + rnn.ws_diff_states_size = rnn.is_training + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) + * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld + * sizeof(float) + : (size_t)0; + rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb + * rnn.gates_ws_ld * sizeof(float); + + /* set other sizes */ + rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float); + rnn.ws_cell_comp_size + = rnn.is_lbr || rnn.dt_conf != all_f32 + ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float) + : 0; + rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer + * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float); + rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic + * sizeof(float); +} + +int rnn_utils::get_good_ld(int dim, int sizeof_dt) { + // we want matrices leading dimentions to be 64-byte aligned, + // and not divisible by 256 to avoid 4K aliasing effects + int ld = rnd_up(dim, 64 / sizeof_dt); + return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld; +} + +void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_states_offset, size_t &ws_c_states_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size) { + + const size_t page_size = 4096; // 2097152; + size_t current_offset; + /* Mandatory workspaces: go to workspace if use_workspace, scratchpad + * otherwise */ + current_offset = 0; // assumes the workspace base pointer is page aligned + ws_gates_offset = current_offset; + current_offset += rnn.ws_gates_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_states_offset = current_offset; + current_offset += rnn.ws_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_c_states_offset = current_offset; + current_offset += rnn.ws_c_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_diff_states_offset = current_offset; + current_offset += rnn.ws_diff_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_grid_comp_offset = current_offset; + current_offset += rnn.ws_grid_comp_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_cell_comp_offset = current_offset; + current_offset += rnn.ws_cell_comp_size; + + workspace_size = rnn.use_workspace ? current_offset : 0; + + /* Optional scratchpads */ + // Assumes the scratchpad base pointer is page aligned. + // If use_workspace, the following goes to scratchpad alone, + // otherwise, all goes to scratchpad and continue incrementing offset + current_offset = rnn.use_workspace ? 0 : current_offset; + + if (rnn.copy_bias) { + current_offset = utils::rnd_up(current_offset, page_size); + ws_bias_offset = current_offset; + current_offset += rnn.ws_bias_size; + } + + scratchpad_size = current_offset; +} + +void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size) { + size_t ws_gates_offset, ws_states_offset, ws_c_states_offset, + ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset; + set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_diff_states_offset, + ws_c_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset, scratchpad_size, workspace_size); +} + +status_t rnn_utils::set_good_strides( + memory_desc_t &weights_md, format_tag_t tag) { + auto &strides = weights_md.format_desc.blocking.strides; + auto dims = weights_md.dims; + + if (tag == ldigo) { + strides[2] = rnn_utils::get_good_ld((int)strides[2], + (int)types::data_type_size(weights_md.data_type)); + strides[1] = dims[2] * strides[2]; + strides[0] = dims[1] * strides[1]; + } else if (tag == ldgoi) { + strides[4] = rnn_utils::get_good_ld((int)strides[4], + (int)types::data_type_size(weights_md.data_type)); + strides[3] = dims[4] * strides[4]; + strides[1] = dims[3] * strides[3]; + strides[0] = dims[1] * strides[1]; + } else + return status::unimplemented; + + return status::success; +} + +status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn, + memory_desc_t &weights_md, bool is_iter) { + using namespace format_tag; + bool use_packed_gemm = is_iter + ? rnn.use_iter_packed_gemm + : rnn.use_layer_packed_gemm; + if (use_packed_gemm) { + weights_md.format_kind = format_kind::rnn_packed; + rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc; + rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p; + if (is_iter) { + rnn_pdata.n = rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_iter; + array_copy(rnn_pdata.parts, rnn.parts_weights_iter, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset; + rnn_pdata.size = rnn.weights_iter_pack_size; + } else { + rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_layer; + array_copy(rnn_pdata.parts, rnn.parts_weights_layer, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset; + rnn_pdata.size = rnn.weights_layer_pack_size; + } + } else { + CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + // Adjust strides for good leading dimension in GEMM + CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + } + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp new file mode 100644 index 0000000000..99eb787a64 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_UTILS_HPP +#define RNN_UTILS_HPP + +#include "mkldnn.h" + +#include "cpu_rnn_pd.hpp" + + +#define rnn_elemwise_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \ + src_data_t *states_t_l_, float *c_states_t_l_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_l_, float *diff_states_t_lp1_, \ + float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \ + float *ws_cell_) const + +#define rnn_cell_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \ + float *c_states_t_l_, float *diff_states_t_l_, \ + weights_data_t **w_layer_, weights_data_t **w_iter_, \ + float **bias_, src_data_t *states_t_lm1_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_lp1_, float *diff_states_tp1_l_, \ + float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \ + acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const + +#define rnn_grid_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \ + weights_data_t **weights_states_, float **bias_, \ + src_data_t *ws_states_, float *ws_c_states_, \ + float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \ + float *ws_grid_, float *diff_weights_layer_, \ + float *diff_weights_iter_, float *diff_bias_) const + +#define rnn_gemm_sig(f) \ + void f(const char transA, const char transB, int m, int n, int k, \ + const float alpha, const weights_data_t *a_, const int ldA, \ + const src_data_t *b_, const int ldB, const float beta, \ + acc_data_t *c_, const int ldC) const + +#define rnn_bias_prepare_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \ + float *scratch_bias_) const + +#define rnn_bias_finalize_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \ + const float *w_iter_comp, const float *w_layer_comp) const + +#define rnn_weights_assign_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \ + int ld, int OC_size, int IC_size, const int n_parts, \ + const int *gates_per_part, const size_t *part_weights_pack_size, \ + weights_data_t **weights_, const weights_data_t *w_, \ + float **bias_, const float *b_, float *scratch_bias_) const + + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace rnn_utils { + +using namespace mkldnn::impl::utils; + +enum execution_direction_t { + l2r, + r2l, + bi_concat, + bi_sum, +}; + +enum data_type_conf_t { + all_f32, + u8u8u8f32, + f32u8f32f32, + u8u8u8u8, + f32u8f32u8 +}; + +struct rnn_conf_t { + execution_direction_t exec_dir; + data_type_conf_t dt_conf; + int n_layer, n_iter, n_dir, n_gates, n_states; + int mb; + int slc, sic, dic, dlc; + int gates_ld, gates_nld, gates_ws_ld; + int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS]; + int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS]; + int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS], + part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + bool weights_layer_is_packed, weights_iter_is_packed; + /* Size of packed data in bytes */ + size_t weights_layer_comp_offset, weights_layer_pack_size, + weights_iter_comp_offset, weights_iter_pack_size; + + bool copy_bias; + int weights_layer_ld, weights_layer_nld; + int diff_weights_layer_ld, diff_weights_layer_nld; + int weights_iter_ld, weights_iter_nld; + int diff_weights_iter_ld, diff_weights_iter_nld; + int states_nld, states_ws_ld; + int weights_iter_compensation_size, weights_layer_compensation_size; + bool is_fwd, is_training, is_lbr; + bool use_workspace; + + /* Size of workspace for each tensor in bytes */ + size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size, + ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size; + bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm, + use_iter_packed_gemm; +}; + +bool is_ldigo(const memory_desc_wrapper &md); +bool is_ldgoi(const memory_desc_wrapper &md); + +int get_good_ld(int dim, int sizeof_dt); + +void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d); + +void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d); + +void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_h_state_offset, size_t &ws_c_state_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size); + +void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size); +status_t set_expected_desc( + rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter); +status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag); + +template <typename T> +struct ws_gates_aoc { + ws_gates_aoc(const rnn_conf_t &rnn, T *data) + : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {} + T &operator()(int batch, int gate, int dic) { + return gates_(batch, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator<T, 2> gates_; + int DIC_; +}; +using ws_gates_aoc_t = ws_gates_aoc<float>; +using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>; + +struct bias_aoc_t { + bias_aoc_t(const rnn_conf_t &rnn, const float *data) + : bias_(data, rnn.n_bias, rnn.dic) {} + const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator<const float, 2> bias_; +}; + +template <typename T> +struct ws_states_aoc { + ws_states_aoc(const rnn_conf_t &rnn, T *data) + : state_(data, rnn.states_nld, rnn.states_ws_ld) {} + T &operator()(int batch, int dic) { return state_(batch, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator<T, 2> state_; +}; +using ws_states_aoc_t = ws_states_aoc<float>; +using ws_states_aoc_u8_t = ws_states_aoc<uint8_t>; + +struct ws_diff_states_aoc_t { + ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld, + rnn.states_ws_ld) {} + float &operator()(int state_n, int batch, int dic) { + return diff_states_(state_n, 0, batch, dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator<float, 4> diff_states_; +}; + +struct ws_diff_w_iter_aoc_t { + ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_weights_iter_( + data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld) + , DIC_(rnn.dic) {} + float &operator()(int sic, int gate, int dic) { + return diff_weights_iter_(sic, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator<float, 2> diff_weights_iter_; + int DIC_; +}; +} +} +} +} +#endif |