/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
  General architecture

  For diff states, we have n_states + 1 buffers: n_states diffs to propagate
  to the previous iteration, and 1 state diff to propagate to the previous
  layer.
  index 0 is dh for cell(t-1, l) to consume
  index 1 is dc for cell(t-1, l) to consume
  index 2 is dh for cell(t, l-1) to consume
  This indexing allows the elemwise functions to use the same indexing for
  all states; only the cell execution function is impacted.
 */

#include "math_utils.hpp"
#include "mkldnn_thread.hpp"

#include "ref_rnn.hpp"
#include "../gemm/gemm.hpp"
#include "../simple_q10n.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::memory_tracking::names;
using namespace rnn_utils;
#define AOC array_offset_calculator

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction(
        const rnn_conf_t &rnn, const acc_data_t *ws_gates_,
        float *diff_bias_) const {
    auto body = [&](int i, int k) {
        for (int j = 0; j < rnn.mb; j++)
            diff_bias_[i * rnn.dic + k]
                    += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k];
    };

    // @todo block k on simd-width
#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \
    /* icc 17.0 has a problem with simd collapse */ \
    && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
#pragma omp parallel for simd collapse(2)
    for (int i = 0; i < rnn.n_gates; i++)
        for (int k = 0; k < rnn.dic; k++)
            body(i, k);
#else
    parallel_nd(rnn.n_gates, rnn.dic, body);
#endif
}
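/* Illustration only (not compiled into the primitive): gates_reduction above
 * accumulates the bias gradient by summing the gate workspace over the
 * minibatch. A minimal scalar sketch of the same reduction, with strides
 * simplified to a hypothetical dense layout:
 *
 *     // diff_bias[g][k] = sum over minibatch of ws_gates[b][g][k]
 *     for (int g = 0; g < n_gates; ++g)
 *         for (int k = 0; k < dic; ++k) {
 *             float acc = 0.f;
 *             for (int b = 0; b < mb; ++b)
 *                 acc += ws_gates[(b * n_gates + g) * dic + k];
 *             diff_bias[g * dic + k] += acc;
 *         }
 *
 * Both paths above compute exactly this: the (i, k) loops are parallelized
 * while the minibatch loop stays serial per element, which avoids a
 * cross-thread reduction. */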
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) {
    assert(ldA * ldB * ldC != 0);
    extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB,
            &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm);
}

template <>
rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) {
    assert(!"non packed gemm is disabled for int8");
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) {
#if (USE_MKL_PACKED_GEMM)
    assert(transA == 'N');
    cblas_sgemm_compute(CblasColMajor, CblasPacked,
            (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_,
            ldB, beta, c_, ldC);
#else
    UNUSED(transA);
    UNUSED(transB);
    UNUSED(m);
    UNUSED(n);
    UNUSED(k);
    UNUSED(alpha);
    UNUSED(ldA);
    UNUSED(b_);
    UNUSED(ldB);
    UNUSED(beta);
    UNUSED(c_);
    UNUSED(ldC);
    assert(!"packed gemm is disabled");
#endif
}

template <>
rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) {
#if (USE_MKL_PACKED_GEMM)
    int8_t offseta = 0, offsetb = 0;
    int32_t offsetc = 0;
    cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked,
            CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_,
            ldB, offsetb, beta, c_, ldC, &offsetc);
#else
    UNUSED(transA);
    UNUSED(transB);
    UNUSED(m);
    UNUSED(n);
    UNUSED(k);
    UNUSED(alpha);
    UNUSED(ldA);
    UNUSED(b_);
    UNUSED(ldB);
    UNUSED(beta);
    UNUSED(c_);
    UNUSED(ldC);
    assert(!"packed gemm is disabled");
#endif
}
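/* Illustration only: all GEMMs in this file use column-major layout. For the
 * layer GEMM of the forward pass, the dimensions map as
 *
 *     m = n_gates * dic   // rows of the weights / output gates
 *     n = mb [* n_iter]   // minibatch, times n_iter when the gemm is merged
 *     k = slc             // input channels of the layer
 *
 * so a single call computes ws_gates = W_layer * src_layer for every gate at
 * once. A plain-sgemm sketch of the unmerged layer GEMM, with hypothetical
 * dense leading dimensions rather than the actual strided workspace:
 *
 *     // C[m x n] = 1.0 * A[m x k] * B[k x n] + 0.0 * C
 *     sgemm('N', 'N', n_gates * dic, mb, slc, 1.0f,
 *             w_layer, n_gates * dic, src, slc, 0.0f, gates, n_gates * dic);
 */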
//*************** Grid computations strategy: linear ***************//
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_grid_execution_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) {
    AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
    AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
    AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            (rnn.n_states + 1), rnn.n_iter + 1,
            rnn.states_nld * rnn.states_ws_ld);
    AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
            rnn.gates_nld * rnn.gates_ws_ld);
    AOC<weights_data_t *, 3> weights_input(
            weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
    AOC<weights_data_t *, 3> weights_states(
            weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
    AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
    AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer,
            rnn.n_dir, rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
    AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir,
            rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
    AOC<float, 3> diff_bias(
            diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
    AOC<float, 4> ws_grid(
            ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);

    // We run the grid of computation
    for (int dir = 0; dir < rnn.n_dir; dir++) {
        for (int j = 0; j < rnn.n_layer; j++) {
            int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;

            if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) {
                (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic,
                        rnn.mb * rnn.n_iter, rnn.slc, 1.0,
                        weights_input(lay, dir, 0), rnn.weights_iter_ld,
                        &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0,
                        &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld);
            }

            for (int i = 0; i < rnn.n_iter; i++) {
                int iter = (aprop == prop_kind::forward) ? i
                                                         : rnn.n_iter - i - 1;
                (this->*cell_func)(rnn,
                        &(ws_states(lay + 1, dir, iter + 1, 0)),
                        &(ws_c_states(lay + 1, dir, iter + 1, 0)),
                        &(ws_diff_states(lay, dir, 0, iter, 0)),
                        &(weights_input(lay, dir, 0)),
                        &(weights_states(lay, dir, 0)),
                        &(bias(lay, dir, 0)),
                        &(ws_states(lay, dir, iter + 1, 0)),
                        &(ws_states(lay + 1, dir, iter, 0)),
                        &(ws_c_states(lay + 1, dir, iter, 0)),
                        &(ws_diff_states(lay + 1, dir, 0, iter, 0)),
                        &(ws_diff_states(lay, dir, 0, iter + 1, 0)),
                        &(diff_weights_layer(lay, dir, 0)),
                        &(diff_weights_iter(lay, dir, 0)),
                        &(diff_bias(lay, dir, 0)),
                        &(ws_gates(lay, dir, iter, 0)),
                        &(ws_grid(lay, dir, iter, 0)),
                        ws_cell_);
            }

            if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) {
                (this->*gemm_layer_func)('N', 'N', rnn.slc,
                        rnn.mb * rnn.n_iter, rnn.n_gates * rnn.dic, 1.0,
                        weights_input(lay, dir, 0), rnn.weights_layer_ld,
                        (src_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld, 0.0,
                        (acc_data_t *)(&(ws_diff_states(
                                lay, dir, rnn.n_states, 0, 0))),
                        rnn.states_ws_ld);
                gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc,
                        rnn.mb * rnn.n_iter, 1.0,
                        (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld,
                        (src_data_t *)(&(ws_states(lay, dir, 1, 0))),
                        rnn.states_ws_ld, 1.0,
                        (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))),
                        rnn.diff_weights_layer_ld);
            }

            if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
                gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic,
                        rnn.mb * rnn.n_iter, 1.0,
                        (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld,
                        (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))),
                        rnn.states_ws_ld, 1.0,
                        (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))),
                        rnn.diff_weights_iter_ld);
            }
        }
    }
}

//********* GRID computations strategy: utility functions **********//

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer(
        const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
        float *__restrict ws_diff_states_, const src_data_t *__restrict xt_,
        const float *__restrict diff_dst_layer_) const {

    AOC<src_data_t, 4> ws_states(
            ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto xt_d = memory_desc_wrapper(pd()->src_md(0));

    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        auto xxt = xt_ + xt_d.blk_off(it, b);
        src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0));
        src_data_t *ws_r2l_ptr
                = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
        if (rnn.exec_dir != r2l)
            for (int c = 0; c < rnn.slc; c++)
                ws_l2r_ptr[c] = xxt[c];
        if (rnn.exec_dir != l2r)
            for (int c = 0; c < rnn.slc; c++)
                ws_r2l_ptr[c] = xxt[c];
    });
}
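/* Illustration only: for a bidirectional grid, copy_init_layer above writes
 * the same input timestep twice, once in chronological order for the l2r
 * direction and once reversed for the r2l direction:
 *
 *     // t = 0 .. n_iter-1, hypothetical dense layout
 *     ws_states(dir = 0,         iter = t + 1)       <- x(t)   // l2r
 *     ws_states(dir = n_dir - 1, iter = n_iter - t)  <- x(t)   // r2l
 *
 * so each direction's cell can always read its input at (iter) and write its
 * output at (iter + 1) without knowing which direction it runs in. */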
template <>
void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn,
        src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_,
        const float *diff_dst_layer_) const {
    AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0));

    switch (rnn.exec_dir) {
    case bi_concat:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
                ws_diff_states(rnn.n_layer, 1, rnn.n_states,
                        rnn.n_iter - it - 1, b, s)
                        = diff_dst_layer_x[rnn.dic + s];
            }
        });
        break;
    case bi_sum:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
                ws_diff_states(rnn.n_layer, 1, rnn.n_states,
                        rnn.n_iter - it - 1, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    case l2r:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    case r2l:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x = diff_dst_layer_
                    + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    default: assert(!"Unsupported direction"); break;
    }
}

/* For int8 configurations, input iteration states may be of type f32 or u8.
 * Internally, the h state is always stored in u8 and the c state is always
 * stored in f32.
 * If input states are of type u8, the h state is copied and the c state is
 * dequantized.
 * If input states are of type f32, the h state is quantized and the c state
 * is copied.
 */
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename input_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter(
        const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
        float *__restrict ws_c_states_, float *__restrict ws_diff_states_,
        const input_data_t *__restrict firstit_states_,
        const float *__restrict diff_dst_iter_) const {
    AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;

    const bool quantize = pd()->with_src_iter()
            && pd()->src_md(1)->data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_q = [&](input_data_t f) {
        if (quantize) {
            float qf = f * data_scale + data_shift;
            return qz_a1b0<float, src_data_t>()(qf);
        } else
            return (src_data_t)f;
    };

    const bool dequantize = pd()->with_src_iter()
            && pd()->src_md(1)->data_type == data_type::u8;
    auto maybe_deq = [&](input_data_t s) {
        if (dequantize)
            return (((float)s - data_shift) / data_scale);
        else
            return (float)s;
    };

    auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1));
    if (firstit_states_) {
        parallel_nd(
                rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
                    for (int s = 0; s < rnn.sic; s++)
                        ws_states(lay + 1, dir, 0, b, s)
                                = maybe_q(firstit_states_[firstit_states_d
                                                  .blk_off(lay, dir, 0, b, s)]);
                    if (pd()->cell_kind() == alg_kind::vanilla_lstm)
                        for (int s = 0; s < rnn.sic; s++)
                            ws_c_states(lay + 1, dir, 0, b, s)
                                    = maybe_deq(firstit_states_
                                            [firstit_states_d.blk_off(
                                                    lay, dir, 1, b, s)]);
                });
    } else {
        parallel_nd(
                rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
                    for (int j = 0; j < rnn.sic; j++) {
                        ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0;
                        ws_c_states(lay + 1, dir, 0, b, j) = 0.0f;
                    }
                });
    }
}
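/* Illustration only: with the data quantization parameters (scale, shift)
 * stored on the attributes, maybe_q / maybe_deq above implement
 *
 *     quantize:    q = saturate_u8(round(f * data_scale + data_shift))
 *     dequantize:  f = ((float)q - data_shift) / data_scale
 *
 * For example, with hypothetical data_scale = 63.5f and data_shift = 64.5f,
 * an f32 state of 0.5f maps to round(0.5 * 63.5 + 64.5) = 96 in u8, and 96
 * maps back to (96 - 64.5) / 63.5 ~= 0.496. The rounding and saturation
 * themselves are done by qz_a1b0 from simple_q10n.hpp. */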
template <>
template <typename input_data_t>
void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn,
        src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_,
        const input_data_t *firstit_states_,
        const float *diff_dst_iter_) const {
    AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1));
    if (diff_dst_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int b) {
                    array_copy(&(ws_diff_states(
                                       lay, dir, state, rnn.n_iter, b, 0)),
                            diff_dst_iter_
                                    + diff_dst_iter_d.blk_off(
                                              lay, dir, state, b),
                            rnn.dic);
                });
    } else {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int i) {
                    for (int j = 0; j < rnn.dic; j++)
                        ws_diff_states(lay, dir, state, rnn.n_iter, i, j)
                                = 0.0f;
                });
    }
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename dst_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer(
        const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer,
        const src_data_t *ws_states_, const float *ws_diff_states_) const {
    auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0));
    AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float shift = (pd()->attr()->rnn_data_qparams_.shift_);
    float scale = (pd()->attr()->rnn_data_qparams_.scale_);

    const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_deq = [&](src_data_t s) {
        if (dequantize)
            return (dst_data_t)(((float)s - shift) / scale);
        else
            return (dst_data_t)s;
    };

    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        int dir = 0;
        if (rnn.exec_dir != r2l) {
            for (int s = 0; s < rnn.dic; s++) {
                dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
                        = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s));
            }
            dir = 1;
        }
        if (rnn.exec_dir != l2r) {
            for (int s = 0; s < rnn.dic; s++)
                switch (rnn.exec_dir) {
                case bi_sum:
                    dst_layer_[dst_layer_d.blk_off(it, b, s)]
                            += maybe_deq(ws_states(
                                    rnn.n_layer, dir, rnn.n_iter - it, b, s));
                    break;
                default:
                    dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
                            = maybe_deq(ws_states(
                                    rnn.n_layer, dir, rnn.n_iter - it, b, s));
                }
        }
    });
}
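/* Illustration only: for bidirectional execution, copy_res_layer above merges
 * the two directions into the destination as either a concatenation or a sum
 * over the channel dimension:
 *
 *     bi_concat: dst(t, b, 0..dic-1)    = l2r state(t)
 *                dst(t, b, dic..2dic-1) = r2l state(t)
 *     bi_sum:    dst(t, b, 0..dic-1)    = l2r state(t) + r2l state(t)
 *
 * The r2l state is read at (n_iter - it) because that direction's outputs
 * were stored in reversed iteration order by copy_init_layer. */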
template <>
template <typename dst_data_t>
void ref_rnn_bwd_f32_t::copy_res_layer(const rnn_conf_t &rnn,
        dst_data_t *dst_layer_, float *diff_src_layer_,
        const src_data_t *ws_states_, const float *ws_diff_states_) const {
    auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0));
    AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
            rnn.states_ws_ld);

    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        int dir = 0;
        for (int s = 0; s < rnn.slc; s++) {
            float *dst_addr = diff_src_layer_
                    + diff_src_layer_d.blk_off(
                              (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it,
                              b, dir * rnn.slc + s);
            float res = ws_diff_states(0, 0, rnn.n_states, it, b, s);
            if (rnn.n_dir - 1)
                res += ws_diff_states(
                        0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s);
            dst_addr[0] = res;
        }
    });
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename output_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter(
        const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
        const src_data_t *ws_states_, float *ws_c_states_,
        const float *ws_diff_states_) const {
    auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1));
    AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;

    const bool quantize = pd()->with_dst_iter()
            && pd()->dst_md(1)->data_type == data_type::u8
            && rnn.dt_conf != all_f32;
    auto maybe_q = [&](float f) {
        if (quantize) {
            float qf = f * data_scale + data_shift;
            return qz_a1b0<float, output_data_t>()(qf);
        } else
            return (output_data_t)f;
    };

    const bool dequantize = pd()->with_dst_iter()
            && pd()->dst_md(1)->data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_deq = [&](src_data_t s) {
        if (dequantize)
            return (output_data_t)(((float)s - data_shift) / data_scale);
        else
            return (output_data_t)s;
    };
    if (dst_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
                [&](int lay, int dir, int b) {
                    for (int s = 0; s < rnn.dic; s++) {
                        dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)]
                                = maybe_deq(ws_states(
                                        lay + 1, dir, rnn.n_iter, b, s));
                    }
                    if (pd()->cell_kind() == alg_kind::vanilla_lstm)
                        for (int s = 0; s < rnn.dic; s++) {
                            dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)]
                                    = maybe_q(ws_c_states(
                                            lay + 1, dir, rnn.n_iter, b, s));
                        }
                });
    }
}

template <>
template <typename output_data_t>
void ref_rnn_bwd_f32_t::copy_res_iter(const rnn_conf_t &rnn,
        output_data_t *dst_iter_, float *diff_src_iter_,
        const src_data_t *ws_states_, float *ws_c_states_,
        const float *ws_diff_states_) const {
    auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1));
    AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
            rnn.states_ws_ld);
    if (diff_src_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int b) {
                    for (int s = 0; s < rnn.sic; s++) {
                        diff_src_iter_[diff_src_iter_d.blk_off(
                                lay, dir, state, b, s)]
                                = ws_diff_states(lay, dir, state, 0, b, s);
                    }
                });
    }
}
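/* Illustration only: bias_prepare below builds a per-(layer, direction, part)
 * table of raw pointers into either the user-provided bias or its scratchpad
 * copy. A minimal sketch of the resulting indexing, with hypothetical sizes:
 *
 *     // with n_parts_bias = 2, parts_bias = {3, 1}, and dic = 64, the
 *     // pointers advance by parts_bias[p] * dic floats per part:
 *     bias(l, d, 0) -> base + 0        // biases of the first 3 gates
 *     bias(l, d, 1) -> base + 3 * 64   // bias of the remaining gate
 *
 * The copy into the scratchpad only happens when rnn.copy_bias is set, e.g.
 * when bias_finalize needs to adjust the values in place for int8. */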
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_bias_prepare_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) {
    /* Original set of bias provided by the user */
    AOC<const float, 3> b(b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
    /* Array of pointers initialized in packing */
    AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
    AOC<float, 3> scratch_bias(
            scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);

    if (rnn.copy_bias) {
        parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic,
                [&](size_t i) { scratch_bias_[i] = b_[i]; });
    }

    for (int i = 0; i < rnn.n_layer; i++) {
        for (int d = 0; d < rnn.n_dir; d++) {
            int offset_bias = 0;
            for (int p = 0; p < rnn.n_parts_bias; p++) {
                bias(i, d, p) = rnn.copy_bias
                        ? (float *)&scratch_bias(i, d, offset_bias)
                        : (float *)&b(i, d, offset_bias);
                offset_bias += rnn.parts_bias[p] * rnn.dic;
            }
        }
    }
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_bias_finalize_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) {
    if (rnn.dt_conf != all_f32) {
        float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
        float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
        float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
        bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;
        for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
            for (int j = 0; j < rnn.n_bias * rnn.dic; j++) {
                size_t off = i * rnn.n_bias * rnn.dic + j;
                float weights_scale
                        = scale_per_oc ? weights_scales[j] : weights_scales[0];
                scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
                        * data_shift / (weights_scale * data_scale);
            }
    }
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type,
        weights_type>::assign_packed_weights)) {
    assert(md->format_kind == format_kind::rnn_packed);
    const auto packed_desc = md->format_desc.rnn_packed_desc;
    AOC<weights_data_t *, 3> weights(
            weights_, rnn.n_layer, rnn.n_dir, packed_desc.n_parts);

    size_t offset_packed = 0;
    for (int l = 0; l < rnn.n_layer; l++)
        for (int d = 0; d < rnn.n_dir; d++) {
            for (int p = 0; p < packed_desc.n_parts; p++) {
                weights(l, d, p) = (weights_data_t *)&w_[offset_packed];
                offset_packed += packed_desc.part_pack_size[p]
                        / sizeof(weights_data_t);
            }
        }
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_weights_assign_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) {
    assert(md->format_kind == format_kind::blocked);
    const auto &blk = md->format_desc.blocking;
    /* Original set of weights provided by the user */
    AOC<const weights_data_t, 3> w(
            w_, rnn.n_layer, rnn.n_dir, (int)blk.strides[1]);
    /* Array of pointers for each part of weights */
    AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);

    for (int i = 0; i < rnn.n_layer; i++)
        for (int d = 0; d < rnn.n_dir; d++) {
            size_t offset_weights = 0;
            for (int p = 0; p < n_parts; p++) {
                weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights);
                offset_weights += gates_per_part[p] * blk.strides[3];
            }
        }
}
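/* Illustration only: the int8 bias compensation applied by bias_finalize
 * above follows from dequantizing the s32 GEMM result. With quantized inputs
 * x_q = data_scale * x + data_shift and weights w_q = weights_scale * w, the
 * s32 product accumulates a spurious term proportional to
 * sum_k(w_q[k]) * data_shift; the precomputed w_iter_comp / w_layer_comp rows
 * hold exactly these per-output-channel weight sums, so
 *
 *     bias'[off] = bias[off]
 *             - (w_iter_comp[off] + w_layer_comp[off]) * data_shift
 *                     / (weights_scale * data_scale)
 *
 * removes the term once per output channel instead of once per GEMM call. */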
//********************* Execution function *********************//
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_(
        const exec_ctx_t &ctx) const {
    const rnn_conf_t &rnn = this->pd()->rnn_;

    auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER);
    auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER);
    auto layer_weights_n_comp
            = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER);
    auto iter_weights_n_comp
            = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER);
    auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);

    auto dst_last_layer = rnn.is_fwd
            ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER)
            : const_cast<char *>(
                      CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER));
    auto dst_last_iter = rnn.is_fwd
            ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER)
            : const_cast<char *>(
                      CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER));

    auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER);
    auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER);

    auto w_layer
            = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp);
    auto w_iter
            = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp);
    auto w_iter_comp = reinterpret_cast<const float *>(
            iter_weights_n_comp + rnn.weights_iter_comp_offset);
    auto w_layer_comp = reinterpret_cast<const float *>(
            layer_weights_n_comp + rnn.weights_layer_comp_offset);

    auto scratchpad = this->scratchpad(ctx);

    auto ptr_wei_layer
            = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer);
    auto ptr_wei_iter
            = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter);
    auto ptr_bias = scratchpad.template get<float *>(key_rnn_ptrs_bia);

    // fetching buffers from the workspace
    // if no workspace was provided we use the scratchpad
    char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
    char *ws_ptr = nullptr;
    if (rnn.use_workspace)
        ws_ptr = rnn.is_fwd
                ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE)
                : const_cast<char *>(
                          CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE));
    char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
    acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_);
    src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_);
    float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_);
    float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_);
    float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_);
    float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_);

    auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER);
    auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER);

    auto diff_weights_layer
            = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER);
    auto diff_weights_iter
            = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER);
    auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);

    // Fetching extra buffers from scratchpad
    float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_);

    // initialize diff_states to 0
    if (aprop == prop_kind::backward)
        array_set(ws_diff_states, 0.0f,
                rnn.ws_diff_states_size / sizeof(float));

    /* Pack (if using the packed gemm API) or copy (if input arrays have a bad
     * leading dimension) */
    (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);

    (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1),
            rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic, rnn.sic,
            rnn.n_parts_weights_iter, rnn.parts_weights_iter,
            rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter, ptr_bias,
            bias, ws_bias);
    (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0),
            rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc,
            rnn.n_parts_weights_layer, rnn.parts_weights_layer,
            rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias,
            bias, ws_bias);

    (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);

    // we first need to copy the initial states and input into ws
    copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer);
    if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
                (const float *)states, diff_dst_iter);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
        copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
                (const uint8_t *)states, diff_dst_iter);
    else
        assert(!"unimplemented");
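    /* Note (inferred from the dispatches above and below; worth
     * double-checking against rnn_utils): the dt_conf names appear to encode
     * the data types of {src_iter, src_layer, dst_iter, dst_layer} in order,
     * e.g. u8u8u8f32 = u8 src_iter, u8 src_layer, u8 dst_iter, f32 dst_layer,
     * which is why that configuration takes the uint8_t path for
     * copy_init_iter and copy_res_iter but the float path for
     * copy_res_layer. */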
    // run the execution on the grid
    (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias,
            ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid,
            diff_weights_layer, diff_weights_iter, diff_bias);

    // Finally we copy the results to the result buffers
    if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states,
                ws_diff_states);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8)
        copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer,
                ws_states, ws_diff_states);
    else
        assert(!"unimplemented");

    if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states,
                ws_c_states, ws_diff_states);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
        copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states,
                ws_c_states, ws_diff_states);
    else
        assert(!"unimplemented");
}

/* Fix for MSVS warning C4661 */
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
template <>
rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise);
template <>
rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise);
template <>
rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);
template <>
rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise);
template <>
rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise);
template <>
rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise);
template <>
rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise);
template <>
rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise);
template <>
rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise);

template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32,
        data_type::f32>;
template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32,
        data_type::f32>;
template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8,
        data_type::s8>;

#undef AOC

}
}
}