/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <math.h>

#include "c_types_map.hpp"
#include "type_helpers.hpp"

#include "cpu_batch_normalization_utils.hpp"
#include "jit_generator.hpp"

#include "nspc_batch_normalization.hpp"

// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases
// The version check below is open-ended: every clang release with a major
// version of 6 or higher is treated as unsafe, any other compiler as safe.
// Note that only some PRAGMA_OMP_SIMD() uses in this file are wrapped in
// `#if SAFE_TO_USE_OMP_SIMD`; the statistics loops of the forward kernel use
// the pragma unconditionally.
#if (defined __clang_major__) && (__clang_major__ >= 6)
#define SAFE_TO_USE_OMP_SIMD 0
#else
#define SAFE_TO_USE_OMP_SIMD 1
#endif

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;

/* Forward batch normalization for data addressed as
 * src[n * SP * C + sp * C + c], i.e. the channel index c varies fastest.
 *
 * For every element the value written to dst is
 *     scale[c] / sqrt(variance[c] + eps) * (src - mean[c]) + shift[c]
 * where scale[c] = scaleshift[c] and shift[c] = scaleshift[C + c] when
 * use_scaleshift() is set, and 1 / 0 otherwise.
 *
 * ReLU handling:
 *  - fuse_bn_relu(): results <= 0 are replaced by 0; when is_training() is
 *    set a 0/1 flag per element is additionally stored into the workspace.
 *  - with_relu_post_op(): negative results are replaced by 0.
 *
 * Statistics:
 *  - stats_is_src() set: mean and variance are read from the MEAN / VARIANCE
 *    inputs.
 *  - otherwise they are computed here (sum divided by SP * N, for both the
 *    mean and the squared deviations) and stored to the MEAN / VARIANCE
 *    outputs when is_training() is set, or to scratchpad buffers otherwise.
 *
 * The work is split between threads along N for the element-wise passes and
 * along C for the cross-thread reductions.
 */
void nspc_batch_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    // save_stats and is_training are both taken from pd()->is_training().
    const bool save_stats = pd()->is_training();
    const bool is_training = pd()->is_training();
    const bool fuse_bn_relu = pd()->fuse_bn_relu();
    const bool calculate_stats = !pd()->stats_is_src();
    const bool with_relu = pd()->with_relu_post_op();

    // tmp_mean / tmp_var hold the per-thread copies of the statistics (and the
    // statistics themselves when they are not saved); ws_reduce holds the
    // per-thread partial sums, C values per thread.
    auto scratchpad = this->scratchpad(ctx);
    auto tmp_mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
    auto tmp_var = scratchpad.get<data_t>(key_bnorm_tmp_var);
    auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);

    auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
    auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);

    // Select where the statistics live. In the first branch they are inputs
    // and are only read below (every write to mean / variance is inside the
    // calculate_stats path), so the const_cast merely unifies the pointer type.
    data_t *mean, *variance;
    if (!calculate_stats) {
        mean = const_cast<data_t *>(
                CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN));
        variance = const_cast<data_t *>(
                CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE));
    } else {
        if (save_stats) {
            mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN);
            variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE);
        } else {
            // NOTE(review): here mean / variance alias the start of tmp_mean /
            // tmp_var, which is also thread 0's mean_loc / variance_loc below,
            // so thread 0's "local copy" step writes the shared buffer onto
            // itself while other threads read it — confirm this is intended.
            mean = tmp_mean;
            variance = tmp_var;
        }
    }

    auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
    // Workspace is only written below when fuse_bn_relu && is_training.
    auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);

    const dim_t N = pd()->MB();
    const dim_t C = pd()->C();
    // Number of elements per channel within one batch item.
    const dim_t SP = pd()->H() * pd()->W() * pd()->D();

    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool use_scaleshift = pd()->use_scaleshift();
    // ReLU post-op: negative values become 0 when with_relu is set.
    auto maybe_post_op
            = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };

    // The body below relies on mkldnn_thr_barrier(), hence the check.
    assert(mkldnn_thr_syncable());
    parallel(0, [&](const int ithr, const int nthr) {
        // [N_s, N_e) is this thread's range of batch items, [C_s, C_e) its
        // range of channels (presumably an even split done by balance211 —
        // confirm against its definition).
        dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0;
        balance211(N, nthr, ithr, N_s, N_e);
        balance211(C, nthr, ithr, C_s, C_e);
        // Per-thread slices of the scratchpad, max(C, 16) elements apart
        // (the lower bound of 16 is presumably padding — TODO confirm against
        // the scratchpad booking code).
        data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr;
        data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr;

        if (calculate_stats) {
            // Pass 1: per-thread partial sums of src for every channel.
            for (dim_t c = 0; c < C; c++)
                ws_reduce[C * ithr + c] = 0.;

            for (dim_t n = N_s; n < N_e; n++)
                for (dim_t sp = 0; sp < SP; sp++)
                    PRAGMA_OMP_SIMD()
                    for (dim_t c = 0; c < C; c++)
                        ws_reduce[C * ithr + c] += src[(size_t)n * SP * C
                            + sp * C + c];

            // All partial sums must be complete before they are combined.
            mkldnn_thr_barrier();

            // Combine the partial sums of all threads for this thread's
            // channels and turn them into the mean.
            for (dim_t c = C_s; c < C_e; c++) {
                mean[c] = 0;
                for (dim_t n = 0; n < nthr; n++)
                    mean[c] += ws_reduce[C * n + c];
                mean[c] /= SP * N;
            }

            // The full mean must be available (and ws_reduce no longer read)
            // before the next pass starts.
            mkldnn_thr_barrier();

            // Take a thread-local copy of the whole mean vector and reset the
            // partial-sum slice for the variance pass.
            for (dim_t c = 0; c < C; c++) {
                mean_loc[c] = mean[c];
                ws_reduce[C * ithr + c] = 0.;
            }

            // Pass 2: per-thread partial sums of squared deviations.
            for (dim_t n = N_s; n < N_e; n++)
                for (dim_t sp = 0; sp < SP; sp++)
                    PRAGMA_OMP_SIMD()
                    for (dim_t c = 0; c < C; c++) {
                        data_t m = src[(size_t)n * SP * C + sp * C + c]
                            - mean_loc[c];
                        ws_reduce[C * ithr + c] += m * m;
                    }

            mkldnn_thr_barrier();

            // Combine into the variance; the divisor is SP * N (no N - 1
            // correction).
            for (dim_t c = C_s; c < C_e; c++) {
                variance[c] = 0;
                for (dim_t n = 0; n < nthr; n++)
                    variance[c] += ws_reduce[C * n + c];
                variance[c] /= SP * N;
            }

            // The full variance must be available before it is copied.
            mkldnn_thr_barrier();

            for (dim_t c = 0; c < C; c++)
                variance_loc[c] = variance[c];
        } else {
            // Statistics were supplied by the caller: read them directly.
            variance_loc = variance;
            mean_loc = mean;
        }

        // Normalization pass over this thread's batch items.
        for (dim_t n = N_s; n < N_e; n++) {
            for (dim_t sp = 0; sp < SP; sp++) {
#if SAFE_TO_USE_OMP_SIMD
                PRAGMA_OMP_SIMD()
#endif
                for (dim_t c = 0; c < C; c++) {
                    data_t sqrt_variance = static_cast<data_t>(
                            sqrtf(variance_loc[c] + eps));
                    // sm: multiplier (scale / stddev), sv: additive shift.
                    data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance;
                    data_t sv = use_scaleshift ? scaleshift[C + c] : 0;
                    size_t d_off = (size_t)n * SP * C + sp * C + c;
                    data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv;
                    if (fuse_bn_relu) {
                        // Fused ReLU: clamp non-positive results and, in
                        // training, record per element whether it passed.
                        if (bn_res <= 0) {
                            bn_res = 0;
                            if (is_training)
                                ws[d_off] = 0;
                        } else {
                            if (is_training)
                                ws[d_off] = 1;
                        }
                    }
                    dst[d_off] = maybe_post_op(bn_res);
                }
            }
        }
    });
}

/* Backward batch normalization for data addressed as
 * x[n * SP * C + sp * C + c], i.e. the channel index c varies fastest.
 *
 * With dd = diff_dst (forced to 0 where the workspace flag is 0 when
 * fuse_bn_relu() is set) and inv_std[c] = 1 / sqrt(variance[c] + eps):
 *
 *     diff_gamma[c] = inv_std[c] * sum over n, sp of (src - mean[c]) * dd
 *     diff_beta[c]  = sum over n, sp of dd
 *
 *     diff_src = gamma[c] * inv_std[c] * (dd
 *                    - diff_beta[c] / (SP * N)
 *                    - (src - mean[c]) * diff_gamma[c] * inv_std[c]
 *                            / (SP * N))
 *
 * The two subtracted terms are applied only when use_global_stats() is not
 * set; gamma[c] is scaleshift[c] when use_scaleshift() is set and 1 otherwise.
 *
 * The work is split between threads along N for the element-wise passes and
 * along C for the cross-thread reduction.
 */
void nspc_batch_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
    auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
    auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
    auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
    auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
    // Workspace is only read below when fuse_bn_relu is set.
    auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);

    auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
    auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);

    auto scratchpad = this->scratchpad(ctx);
    auto tmp_diff_ss = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);

    // diff_gamma / diff_beta are needed for diff_src even when the caller
    // provides no output for them; fall back to the scratchpad in that case.
    if (diff_scaleshift == nullptr)
        diff_scaleshift = tmp_diff_ss;

    const dim_t N = pd()->MB();
    const dim_t C = pd()->C();
    // Number of elements per channel within one batch item.
    const dim_t SP = pd()->D() * pd()->H() * pd()->W();
    // diff_scaleshift layout: C values of diff_gamma followed by C of diff_beta.
    data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C;
    // ws_reduce layout (as used below): nthr slices of C partial diff_gamma
    // sums, followed by nthr slices of C partial diff_beta sums.
    auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);

    const float eps = pd()->desc()->batch_norm_epsilon;
    const bool use_scaleshift = pd()->use_scaleshift();
    const bool calculate_diff_stats = !pd()->use_global_stats();
    const bool fuse_bn_relu = pd()->fuse_bn_relu();

    // The body below relies on mkldnn_thr_barrier(), hence the check.
    assert(mkldnn_thr_syncable());
    parallel(0, [&](const int ithr, const int nthr) {
        // [N_s, N_e) is this thread's range of batch items, [C_s, C_e) its
        // range of channels (presumably an even split done by balance211 —
        // confirm against its definition).
        dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0;
        balance211(N, nthr, ithr, N_s, N_e);
        balance211(C, nthr, ithr, C_s, C_e);

        // Thread-local copies of the reduced vectors. They live in
        // tmp_diff_ss past its first 2 * C elements (which may serve as
        // diff_scaleshift, see above): nthr gamma slices, then nthr beta
        // slices.
        data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr;
        data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr);

        // Pass 1: per-thread partial sums for diff_gamma and diff_beta.
        for (dim_t c = 0; c < C; c++) {
            ws_reduce[C * ithr + c] = 0.;
            ws_reduce[C * nthr + C * ithr + c] = 0.;
        }

        for (dim_t n = N_s; n < N_e; n++)
            for (dim_t sp = 0; sp < SP; sp++)
#if SAFE_TO_USE_OMP_SIMD
                PRAGMA_OMP_SIMD()
#endif
                for (dim_t c = 0; c < C; c++) {
                    const size_t d_off = (size_t)n * SP * C + sp * C + c;
                    // Gradient after the fused ReLU: zero where the workspace
                    // flag is 0.
                    data_t dd;
                    if (fuse_bn_relu)
                        dd = (!ws[d_off]) ? 0 : diff_dst[d_off];
                    else
                        dd = diff_dst[d_off];
                    ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd;
                    ws_reduce[C * nthr + C * ithr + c] += dd;
                }

        // All partial sums must be complete before they are combined.
        mkldnn_thr_barrier();

        // Combine the partial sums of all threads for this thread's channels.
        for (dim_t c = C_s; c < C_e; c++) {
            // Despite the name this is the reciprocal: 1 / sqrt(var + eps).
            data_t sqrt_variance
                    = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
            diff_gamma[c] = 0;
            diff_beta[c] = 0;
            for (dim_t n = 0; n < nthr; n++) {
                diff_gamma[c] += ws_reduce[C * n + c];
                diff_beta[c] += ws_reduce[C * nthr + C * n + c];
            }
            diff_gamma[c] *= sqrt_variance;
        }

        // The full diff_gamma / diff_beta must be available before copying.
        mkldnn_thr_barrier();

        for (dim_t c = 0; c < C; c++) {
            diff_gamma_loc[c] = diff_gamma[c];
            diff_beta_loc[c] = diff_beta[c];
        }

        // Pass 2: compute diff_src for this thread's batch items.
        for (dim_t n = N_s; n < N_e; n++) {
            for (dim_t sp = 0; sp < SP; sp++) {
#if SAFE_TO_USE_OMP_SIMD
                PRAGMA_OMP_SIMD()
#endif
                for (dim_t c = 0; c < C; c++) {
                    const size_t d_off = (size_t)n * SP * C + sp * C + c;
                    data_t gamma = use_scaleshift ? scaleshift[c] : 1;
                    // Reciprocal standard deviation, see the note above.
                    data_t sqrt_variance
                            = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
                    data_t v_diff_src;
                    if (fuse_bn_relu)
                        v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off];
                    else
                        v_diff_src = diff_dst[d_off];
                    // Contribution of the batch statistics to the gradient;
                    // skipped when global statistics are used.
                    if (calculate_diff_stats) {
                        v_diff_src -= diff_beta_loc[c] / (SP * N)
                                + (src[d_off] - mean[c]) * diff_gamma_loc[c]
                                        * sqrt_variance / (SP * N);
                    }
                    v_diff_src *= gamma * sqrt_variance;
                    diff_src[d_off] = v_diff_src;
                }
            }
        }
    });
}

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s