/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef PRIMITIVE_DESC_HPP
#define PRIMITIVE_DESC_HPP

#include "mkldnn.h"

#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "primitive_attr.hpp"
#include "verbose.hpp"

/** Abstract base of primitive descriptors.
 *
 * Stores the engine pointer, a private copy of the primitive attributes, the
 * primitive kind, a verbose info string, and the scratchpad bookkeeping
 * (registry + memory descriptor). Concrete descriptors must implement
 * clone() and create_primitive(); the remaining virtuals have neutral
 * defaults (nullptr / 0 / unused). */
struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
    using md_t = mkldnn::impl::memory_desc_t;

    /** Creates a descriptor for @p engine holding a copy of @p attr.
     * @p attr is dereferenced unconditionally and therefore must be non-null.
     * The scratchpad memory descriptor is value-initialized so that it never
     * holds indeterminate data before init_scratchpad_md() is called. */
    mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
            const mkldnn::impl::primitive_attr_t *attr,
            mkldnn::impl::primitive_kind_t kind)
        : engine_(engine), attr_(*attr), kind_(kind), scratchpad_md_()
    { info_[0] = '\0'; }

    /** Same as above, but with default-constructed attributes. */
    mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
            mkldnn::impl::primitive_kind_t kind)
        : engine_(engine), kind_(kind), scratchpad_md_()
    { info_[0] = '\0'; }

    /** Returns a newly allocated copy of this descriptor. */
    virtual mkldnn_primitive_desc *clone() const = 0;
    virtual ~mkldnn_primitive_desc() {}

    /** Plain accessors for the state captured at construction time. */
    const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
    mkldnn::impl::engine_t *engine() const { return engine_; }
    mkldnn::impl::primitive_kind_t kind() const { return kind_; }

    /** Hook for filling info_; the base implementation leaves it empty. */
    virtual void init_info() {}
    /** Returns the NUL-terminated info string (empty unless init_info()
     * was overridden to fill it). */
    const char *info() const { return info_; }

    /** Access to the registry used to compute the scratchpad size. */
    mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
    { return scratchpad_registry_; }
    const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
    { return scratchpad_registry_; }
    /** Engine associated with the scratchpad; defaults to engine(). */
    virtual mkldnn::impl::engine_t *scratchpad_engine() const
    { return engine_; }

    /** Operation descriptor of the primitive; nullptr in the base class. */
    virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }

    enum class arg_usage_t { unused, input, output };
    /** Tells how the primitive uses argument @p arg.
     * The base implementation only knows about the scratchpad: it is an
     * output whenever the scratchpad memory descriptor is not a zero md.
     * Every other argument is reported as unused. */
    virtual arg_usage_t arg_usage(
            mkldnn::impl::primitive_arg_index_t arg) const {
        using mkldnn::impl::types::is_zero_md;
        if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
            return arg_usage_t::output;
        return arg_usage_t::unused;
    }

    /* Memory descriptor queries: each stub returns nullptr, i.e. the base
     * class reports no memory descriptor of that kind for any index. */
#   define DECLARE_MD_STUB(stub) \
    virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
    { return nullptr; }

    DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
    DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
    DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
    DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
    DECLARE_MD_STUB(workspace_md);
#   undef DECLARE_MD_STUB

    /** Returns the scratchpad memory descriptor for index 0 and nullptr for
     * any other index. */
    const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
        return idx == 0 ? &scratchpad_md_ : nullptr;
    }

    /** (Re)builds scratchpad_md_ from the user-mode scratchpad size: a 1D u8
     * descriptor of that many elements, or a 0-dimensional one when the size
     * is zero.
     * NOTE(review): the status returned by mkldnn_memory_desc_init_by_tag()
     * is ignored here. */
    virtual void init_scratchpad_md() {
        auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
        mkldnn::impl::dims_t dims = { size };
        mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
                mkldnn::impl::data_type::u8, mkldnn_x);
    }

    /** Returns the scratchpad size for the given scratchpad mode: the
     * registry size when @p mode matches the mode stored in the attributes,
     * and 0 otherwise. */
    mkldnn::impl::dim_t scratchpad_size(
            mkldnn::impl::scratchpad_mode_t mode) const {
        if (mode != attr_.scratchpad_mode_) return 0;
        return scratchpad_registry().size();
    }

    /** Number of inputs / outputs; 0 in the base class. */
    virtual int n_inputs() const { return 0; }
    virtual int n_outputs() const { return 0; }

    /** Generic query entry point; defined out of line. */
    virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
            void *result) const;

    /** Creates the primitive described by this descriptor and stores it in
     * @p primitive. */
    virtual mkldnn::impl::status_t create_primitive(
            mkldnn::impl::primitive_t **primitive) const = 0;

    /** Implementation name; concrete descriptors override it. */
    virtual const char *name() const { return "mkldnn_primitive_desc"; }

    /* static magic */

    /** Generic factory for a concrete descriptor type @p pd_t.
     *
     * @param pd       receives the new descriptor on success (left untouched
     *                 on failure)
     * @param adesc    operation descriptor; its kind must equal
     *                 pd_t::base_pkind, otherwise invalid_arguments is
     *                 returned
     * @param attr     attributes forwarded to the pd_t constructor
     * @param engine   engine forwarded to the pd_t constructor
     * @param hint_fwd optional forward descriptor hint; reinterpreted as
     *                 pd_t::hint_class (its kind is checked by assert only)
     * @return success, invalid_arguments, out_of_memory, or unimplemented
     *         when pd_t::init() does not succeed */
    template<typename pd_t>
    static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
            const mkldnn::impl::op_desc_t *adesc,
            const mkldnn::impl::primitive_attr_t *attr,
            mkldnn::impl::engine_t *engine,
            const mkldnn::impl::primitive_desc_t *hint_fwd) {
        using namespace mkldnn::impl;
        using namespace mkldnn::impl::status;
        using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
        if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
        assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
        auto hint =
            reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
        auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
        /* NOTE(review): reachable only if the operator new used for pd_t can
         * return null (presumably the one from c_compatible) -- confirm. */
        if (_pd == nullptr) return out_of_memory;
        /* any init() failure is reported to the caller as unimplemented */
        if (_pd->init() != success) { delete _pd; return unimplemented; }
        _pd->init_info();
        _pd->init_scratchpad_md();
        *pd = _pd;
        return success;
    }

protected:
    mkldnn::impl::engine_t *engine_;
    mkldnn::impl::primitive_attr_t attr_;
    mkldnn::impl::primitive_kind_t kind_;

    /* value-initialized by the constructors, filled by init_scratchpad_md() */
    mkldnn::impl::memory_desc_t scratchpad_md_;

    char info_[MKLDNN_VERBOSE_BUF_LEN];

    mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;

protected:
    /** Compares the workspace of @p fwd_pd with the one of this descriptor
     * (makes sense to use for a backward descriptor).
     * Expectation: this descriptor has already set its workspace, and that
     *              workspace should exactly match the one from @p fwd_pd.
     * @return true when this descriptor has no workspace, or when @p fwd_pd
     *         is non-null, has a workspace, and both descriptors compare
     *         equal; false otherwise */
    bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
        using namespace mkldnn::impl;
        if (!workspace_md()) return true; // the impl lives fine w/o workspace
        return fwd_pd && fwd_pd->workspace_md()
            && *fwd_pd->workspace_md() == *workspace_md();
    }
};

/** Injects the boilerplate overrides of a concrete descriptor type:
 *  - clone(): copy-constructs a new pd_t from *this;
 *  - create_primitive(): constructs the primitive type given by the variadic
 *    arguments from this descriptor, hands it to safe_ptr_assign(), and, at
 *    verbose level 2 or higher, prints the measured creation time together
 *    with info();
 *  - name(): returns impl_name.
 * Must be expanded where pd_t, status_t, primitive_t, get_msec(),
 * safe_ptr_assign() and mkldnn_verbose() are visible. */
#define DECLARE_COMMON_PD_t(impl_name, ...) \
    virtual pd_t *clone() const override { return new pd_t(*this); } \
    virtual status_t create_primitive(primitive_t **p) const override { \
        const double start_ms = get_msec(); \
        auto assign_status = safe_ptr_assign<primitive_t>(*p, \
                new (__VA_ARGS__)(this)); \
        const double elapsed_ms = get_msec() - start_ms; \
        if (mkldnn_verbose()->level >= 2) { \
            printf("mkldnn_verbose,create,%s,%g\n", this->info(), \
                    elapsed_ms); \
            fflush(0); \
        } \
        return assign_status; \
    } \
    virtual const char *name() const override { return impl_name; }
/** Upper-case spelling of DECLARE_COMMON_PD_t; forwards every argument. */
#define DECLARE_COMMON_PD_T(impl_name, ...) \
    DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s