/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "mkldnn_traits.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "cpu_memory.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl;
using namespace mkldnn::impl::data_type;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::format_tag;

enum blk_kind_t { a, b, c, ab, ba, bc, cb };
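
/* typed_zero_pad_blk: zero-initializes the padded tail elements of layouts
 * whose first three logical dimensions may be blocked. The blk_kind template
 * parameter names which of dims 0..2 ("a", "b", "c") form the innermost
 * blocks and in what nesting order; blksize is the combined block size along
 * each blocked dimension.
 *
 * For example (hypothetical shape): an f32 nChw16c tensor with C = 17 is
 * padded to 32 channels, so dims[1] % blksize == 1 and the <f32, b, 16>
 * instantiation zeroes elements 1..15 of the last channel block at every
 * (n, h, w) position. */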
template <data_type_t dt, blk_kind_t blk_kind, int blksize>
void typed_zero_pad_blk(
        const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
    using data_t = typename prec_traits<dt>::type;
    const auto &dims = m_d.dims();
    const auto &pdims = m_d.padded_dims();
    const auto &blk = m_d.blocking_desc();
    auto dim_is_blocked = [&](int dim) {
        for (int i = 0; i < blk.inner_nblks; i++)
            if (blk.inner_idxs[i] == dim)
                return true;
        return false;
    };
    bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1),
         C_blocked = dim_is_blocked(2);

    assert(blk.inner_nblks < 4);
    assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked)
            || (C_blocked && B_blocked));

    const int a_tail_s = A_blocked ? dims[0] % blksize : 0;
    const int b_tail_s = B_blocked ? dims[1] % blksize : 0;
    const int c_tail_s = C_blocked ? dims[2] % blksize : 0;
    assert(a_tail_s || b_tail_s || c_tail_s);

    const int A = A_blocked ? pdims[0] / blksize : dims[0];
    const int B = B_blocked ? pdims[1] / blksize : dims[1];
    const int C = C_blocked ? pdims[2] / blksize : dims[2];
    const int D = m_d.ndims() > 3 ? dims[3] : 1;
    const int E = m_d.ndims() > 4 ? dims[4] : 1;
    const int F = m_d.ndims() > 5 ? dims[5] : 1;
    const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1;

    auto zeroize_tail = [&](data_t *d, const int tail_s) {
        for (int b = tail_s; b < blksize; ++b)
            d[b] = 0;
    };
    auto zeroize_tail_inner = [&](data_t *d, const int tail_s) {
        for (int b1 = 0; b1 < blksize; ++b1)
            for (int b2 = tail_s; b2 < blksize; ++b2)
                d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
                        + b1 % inner_blk] = 0;
    };
    auto zeroize_tail_outer = [&](data_t *d, const int tail_s) {
        for (int b1 = tail_s; b1 < blksize; ++b1)
            for (int b2 = 0; b2 < blksize; ++b2)
                d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
                        + b1 % inner_blk] = 0;
    };
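
    /* The three blocks below visit, for each padded dimension, the last
     * (partially filled) block along that dimension at every combination of
     * the remaining dimensions and zero its tail. zeroize_tail handles a
     * single-level block; zeroize_tail_inner / zeroize_tail_outer handle a
     * two-level combined block where the padded dimension is, respectively,
     * the second or the first letter of blk_kind. For three-level patterns
     * such as 8i16o2i, inner_blk is the width of the innermost split; for
     * plain two-level blocking it is 1. */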
    if (c_tail_s) {
        parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) {
            auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)];
            if (blk_kind == c)
                zeroize_tail(x, c_tail_s);
            else if (blk_kind == bc)
                zeroize_tail_inner(x, c_tail_s);
            else if (blk_kind == cb)
                zeroize_tail_outer(x, c_tail_s);
        });
    }

    if (b_tail_s) {
        parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) {
            auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)];
            if (blk_kind == b)
                zeroize_tail(x, b_tail_s);
            else if (blk_kind == ab || blk_kind == cb)
                zeroize_tail_inner(x, b_tail_s);
            else if (blk_kind == ba || blk_kind == bc)
                zeroize_tail_outer(x, b_tail_s);
        });
    }

    if (a_tail_s) {
        parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) {
            auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)];
            if (blk_kind == a)
                zeroize_tail(x, a_tail_s);
            else if (blk_kind == ba)
                zeroize_tail_inner(x, a_tail_s);
            else if (blk_kind == ab)
                zeroize_tail_outer(x, a_tail_s);
        });
    }
}

/* Generic fallback: handles any blocked format by walking all logical
 * elements (including padding) and zeroing those that fall into a padded
 * region along any dimension. */
template <data_type_t dt>
void typed_zero_pad_generic_blocked(
        const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
    const int ndims = m_d.ndims();
    const auto &dims = m_d.dims();
    const auto &pdims = m_d.padded_dims();

    const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true);

    /* [D_0] .. [D_k][D_k+1] .. [D_ndims - 1]
     *           |    \                    /
     *           |     --------------------
     *          has          contiguous
     *        padding
     *
     * step     <-- D_k+1 * ... * D_ndims-1
     * step_dim <-- k
     */
    ptrdiff_t step = 1;
    int step_dim = ndims - 1;
    for (; step_dim >= 0; --step_dim) {
        if (dims[step_dim] != pdims[step_dim])
            break;
        step *= dims[step_dim];
    }

    assert(step_dim >= 0 && "no zero padding is required");
    if (step_dim < 0)
        return;

    parallel_nd(nelems / step, [&](ptrdiff_t e1) {
        bool need_zero = false;

        ptrdiff_t idx = e1;
        for (int d = step_dim; d >= 0; --d) {
            if (idx % pdims[d] >= dims[d]) {
                need_zero = true;
                break;
            }
            idx /= pdims[d];
        }

        if (need_zero) {
            for (ptrdiff_t e0 = 0; e0 < step; ++e0)
                data[m_d.off_l(e1 * step + e0, true)] = 0;
        }
    });
}
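
/* cpu_memory_t::typed_zero_pad: dispatches to a specialized kernel when the
 * innermost blocking matches one of the common patterns (block sizes 4, 8 or
 * 16 over dims a/b, or the pairs ab, ba, bc, cb); anything else falls through
 * to typed_zero_pad_generic_blocked. */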
template <data_type_t dt>
status_t cpu_memory_t::typed_zero_pad() const {
    const memory_desc_wrapper mdw(md());

    if (mdw.format_kind() != format_kind::blocked)
        return unimplemented;

    if (mdw.nelems(false) == mdw.nelems(true))
        return success;

    auto *data = (typename prec_traits<dt>::type *)data_;
    auto blk = mdw.blocking_desc();

    auto get_blksize = [&](int ind) {
        int blksize = 1;
        for (int i = 0; i < blk.inner_nblks; i++) {
            if (blk.inner_idxs[i] == ind)
                blksize *= blk.inner_blks[i];
        }
        return blksize;
    };
    const int blksize = get_blksize(blk.inner_idxs[0]);

#   define CASE(blksize_, blk_kind) \
    do { \
        if (blksize == blksize_) { \
            typed_zero_pad_blk<dt, blk_kind, blksize_>(mdw, data); \
            return success; \
        } \
    } while(0)

    switch (blk.inner_nblks) {
    case 1:
        if (blk.inner_idxs[0] == 0) {
            CASE(4, a);
            CASE(8, a);
            CASE(16, a);
        } else if (blk.inner_idxs[0] == 1) {
            CASE(4, b);
            CASE(8, b);
            CASE(16, b);
        }
        break;
    case 2:
    case 3:
        if (!IMPLICATION(blk.inner_nblks == 3,
                    blk.inner_idxs[0] == blk.inner_idxs[2]))
            break;

        if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
            CASE(4, ab);
            CASE(8, ab);
            CASE(16, ab);
        } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) {
            CASE(4, ba);
            CASE(8, ba);
            CASE(16, ba);
        }
        if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
            CASE(4, bc);
            CASE(8, bc);
            CASE(16, bc);
        } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) {
            CASE(4, cb);
            CASE(8, cb);
            CASE(16, cb);
        }
        break;
    default: break;
    }

#   undef CASE

    // the last line of defence
    typed_zero_pad_generic_blocked<dt>(mdw, data);
    return success;
}
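
/* Public entry point: zeroing is skipped when there is no data handle, the
 * descriptor is zero/empty, or the layout is not a blocking descriptor;
 * otherwise the call is dispatched on the memory's data type. */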
status_t cpu_memory_t::zero_pad() const {
    memory_desc_wrapper mdw(md());
    const bool skip_zeroing = false
        || data_ == nullptr
        || mdw.is_zero()
        || !mdw.is_blocking_desc();
    if (skip_zeroing) return success;

    switch (mdw.data_type()) {
        case f32: return typed_zero_pad<f32>();
        case s32: return typed_zero_pad<s32>();
        case s8: return typed_zero_pad<s8>();
        case u8: return typed_zero_pad<u8>();
        default: assert(!"memory is undefined"); return unimplemented;
    }

    return unimplemented;
}

}
}
}