diff options
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp | 376 |
1 files changed, 376 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp new file mode 100644 index 0000000000..c2082d7d62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp @@ -0,0 +1,376 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_WINO_REORDER_HPP +#define CPU_WINO_REORDER_HPP + +#include "mkldnn_thread.hpp" + +#include "simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t type_i, data_type_t type_o> +struct wino_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_tag(utils::pick(id.ndims() - 4, + format_tag::oihw, format_tag::goihw)) + && od.format_kind() == format_kind::wino + && utils::one_of(od.wino_desc().wino_format, + mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio); + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + auto &o = memory_desc_wrapper(dst_md()).wino_desc(); + size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block; + size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic; + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_reorder_wino_transform_space, + sizeof(in_data_t) * transform_space_size); + scratchpad.book(key_reorder_wino_plain, + sizeof(out_data_t) * plain_size); + } + }; + +private: + typedef typename prec_traits<type_i>::type in_data_t; + typedef typename prec_traits<type_o>::type out_data_t; + const int unsign_val_in_wino_domain_ = 5; + + wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + r_ = dst_d.wino_desc().r; + w_alpha_ = dst_d.wino_desc().alpha; + wino_format_ = dst_d.wino_desc().wino_format; + + const auto &in_dims = src_d.dims(); + int groups; + int groups_offset; + if (src_d.ndims() == 5) { + groups = in_dims[0]; + groups_offset = 1; + } else { + groups = 1; + groups_offset = 0; + } + assert(groups == 1); // groups are not supported now + MAYBE_UNUSED(groups); + + or_oc_ = in_dims[0 + groups_offset]; + or_ic_ = in_dims[1 + groups_offset]; + kh_ = in_dims[2 + groups_offset]; + kw_ = in_dims[3 + groups_offset]; + + oc_ = dst_d.wino_desc().oc; + ic_ = dst_d.wino_desc().ic; + oc_block_ = dst_d.wino_desc().oc_block; + ic_block_ = dst_d.wino_desc().ic_block; + assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); + nb_oc_ = oc_ / oc_block_; + nb_ic_ = ic_ / ic_block_; + ic2_block_ = 1; + if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + ic2_block_ = dst_d.wino_desc().ic2_block; + oc2_block_ = dst_d.wino_desc().oc2_block; + assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); + + adj_scale_ = dst_d.wino_desc().adj_scale; + + size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; + size_wspace_ = r_ * w_alpha_ * oc_block_; + } + + void transform(out_data_t *__restrict tmp_wei, + const in_data_t *__restrict input, + in_data_t *__restrict wspace) const { + const memory_desc_wrapper src_d(pd()->src_md()); + + const int smask = pd()->attr()->output_scales_.mask_; + const int ndims_mask = math::ilog2q(smask + 1); + const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask); + const float *__restrict scales = pd()->attr()->output_scales_.scales_; + assert(D_mask == 1 || D_mask == (size_t)oc_); + + /* transform weights to winograd domain */ + const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 }, + { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } }; + + const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f }, + { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f }, + { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f }, + { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f }, + { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f }, + { 0.f, 0.f, 1.f } }; + + float *__restrict g; + if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo)) + g = (float *)G_2x2_3x3; + else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + g = (float *)G_4x4_3x3; + else { + assert("Unknown winograd weights target layout"); + return; + } + + int Z = oc_ * ic_; + assert(r_ == kh_ && r_ == kw_); + + for (int iic = 0; iic < ic_; iic++) { + for (int ob = 0; ob < nb_oc_; ob++) { + const in_data_t *__restrict _inp + = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_; + out_data_t *__restrict _out + = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; + + for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; }); + + for_nd(0, 1, r_, w_alpha_, oc_block_, + [&](int ih, int j, int ioc) { + for (int iw = 0; iw < r_; ++iw) { + int inp_oc = ob * oc_block_ + ioc; + int inp_ic = iic; + in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_) + ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw] + : 0.f; + wspace[(ih * w_alpha_ + j) * oc_block_ + ioc] + += inp_v * g[j * r_ + iw]; + } + }); + + for_nd(0, 1, w_alpha_, w_alpha_, oc_block_, + [&](int i, int j, int ioc) { + float t = 0; + for (int k = 0; k < r_; ++k) + t += g[i * r_ + k] + * wspace[(k * w_alpha_ + j) * oc_block_ + ioc]; + if (type_o == data_type::s8) { + const float scale = (D_mask == 1) + ? scales[0] + : scales[ob * oc_block_ + ioc]; + _out[(i * w_alpha_ + j) * Z + ioc] + = qz_b0<in_data_t, out_data_t>()( + (in_data_t)t, scale * adj_scale_); + } else { + _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t; + } + }); + }} + } + + void reorder_to_aaOIoi(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int32_t *__restrict dst_bias = nullptr; + if (type_o == data_type::s8) { + const auto bias_shift = sizeof(out_data_t) * size_wino_wei_; + const size_t bias_size = w_alpha_ * w_alpha_ * oc_; + + dst_bias = (int32_t *)(output + bias_shift); + utils::array_set((int32_t *)dst_bias, 0, bias_size); + } + int index = 0; + for (int u_h = 0; u_h < w_alpha_; u_h++) { + for (int u_w = 0; u_w < w_alpha_; u_w++) { + for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) { + int u_h_shift = u_h * w_alpha_ * ic_ * oc_; + int u_w_shift = u_w * ic_ * oc_; + int u_h_shift_b = u_h * w_alpha_ * oc_; + int u_w_shift_b = u_w * oc_; + int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_; + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int _i = ib * ic_block_; + int _o = ob * oc_block_; + int ic_shift = (_i + i) * oc_; + int oc_shift = (_o + o); + int ic_block_shift = ib * oc_block_ * ic_block_ + i; + int src_offset = + u_h_shift + u_w_shift + ic_shift + oc_shift; + int dst_offset = u_h_shift + u_w_shift + oc_block_shift + + ic_block_shift; + + output[dst_offset] = tmp_wei[src_offset]; + if (type_o == data_type::s8) { + int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift; + if (index != unsign_val_in_wino_domain_) + dst_bias[bias_offset] + -= (128 * (int32_t)output[dst_offset]); + else + dst_bias[bias_offset] = 0; + } + }} + }); + index++; + }} + } + + void reorder_to_aaOio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_, + [&](int u_h, int u_w, int ob) { + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + for (int o = 0; o < oc_block_; o++) { + int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_ + + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o); + + int dst_offset + = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + ob * nb_ic_ * ic_block_ * oc_block_ + + ib * ic_block_ * oc_block_ + i * oc_block_ + o; + output[dst_offset] = tmp_wei[src_offset]; + }}} + }); + } + + void reorder_to_aaOBiOo(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks, + [&](int u_h, int u_w, int occ) { + for (int ib = 0; ib < nb_ic_; ib++) { + out_data_t *__restrict wei_ptr = output + + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib) + * oc2_block_ * ic_block_ * oc_block_; + int wei_offset = 0; + for (int i = 0; i < ic_block_; i++) { + for (int ob2 = 0; ob2 < oc2_block_; ob2++) { + for (int o = 0; o < oc_block_; o++) { + int icp = ib * ic_block_ + i; + int ocp = + occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + wei_ptr[wei_offset + o] = tmp_wei[src_offset]; + } + wei_offset += oc_block_; + }} + } + }); + } + + void reorder_to_OBaaIBOIio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int ic_chunks = nb_ic_ / ic2_block_; + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_, + [&](int occ, int u_h, int u_w) { + for (int icc = 0; icc < ic_chunks; icc++) { + for (int ob = 0; ob < oc2_block_; ob++) { + int ocp = (occ * oc2_block_ + ob) * oc_block_; + for (int ib = 0; ib < ic2_block_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int icp = (icc * ic2_block_ + ib) * ic_block_ + i; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + int wei_offset + = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w) + * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_ + + ib) * ic_block_ + i) * oc_block_; + for (int o = 0; o < oc_block_; o++) + output[wei_offset + o] = tmp_wei[src_offset + o]; + }} + }} + }); + } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + + auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get<void>( + memory_tracking::names::key_reorder_wino_transform_space); + auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get<void>( + memory_tracking::names::key_reorder_wino_plain); + + transform(tmp_wei, input, wspace); + + /* reorder to winograd domain */ + switch (wino_format_) { + case mkldnn_wino_wei_aaOIoi: + reorder_to_aaOIoi(output, tmp_wei); break; + case mkldnn_wino_wei_aaOio: + reorder_to_aaOio(output, tmp_wei); break; + case mkldnn_wino_wei_aaOBiOo: + reorder_to_aaOBiOo(output, tmp_wei); break; + case mkldnn_wino_wei_OBaaIBOIio: + reorder_to_OBaaIBOIio(output, tmp_wei); break; + default: assert("Unknown wino format"); break; + } + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int r_, w_alpha_; + int ic_, oc_, or_ic_, or_oc_, kh_, kw_; + int oc_block_, ic_block_, oc2_block_, ic2_block_; + float adj_scale_; + int nb_oc_, nb_ic_; + mkldnn_wino_memory_format_t wino_format_; + int size_wino_wei_; + int size_wspace_; +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif |