diff options
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/common')
65 files changed, 11287 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp new file mode 100644 index 0000000000..1a51d8562b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) { + bool args_ok = true + && !any_null(bnrm_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data, backward) + && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto bd = batch_normalization_desc_t(); + bd.primitive_kind = primitive_kind::batch_normalization; + bd.prop_kind = prop_kind; + + bd.data_desc = *data_desc; + bd.diff_data_desc = zero_md(); + if ( one_of(bd.prop_kind,backward_data, backward) ) + bd.diff_data_desc = *diff_data_desc; + + dims_t scaleshift_dims = { 2, data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2, + scaleshift_dims, data_type::f32, mkldnn_nc); + bd.diff_data_scaleshift_desc = zero_md(); + if (bd.prop_kind == backward) { + bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc; + } + + dims_t stats_dims = { data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims, + data_type::f32, mkldnn_x); + bd.variance_desc = bd.mean_desc; + bd.batch_norm_epsilon = epsilon; + + unsigned bnorm_flags = + mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu; + if ((~bnorm_flags & flags) != 0) return invalid_arguments; + + bd.flags = flags; + + bool consistency = true + && utils::one_of(bd.data_desc.ndims, 2, 4, 5); + if (bd.prop_kind == backward_data) + consistency = consistency + && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5) + && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims, + bd.diff_data_desc.ndims); + if (!consistency) return invalid_arguments; + + *bnrm_desc = bd; + return success; +} +} + +status_t mkldnn_batch_normalization_forward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, float epsilon, unsigned flags) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr, + epsilon, flags); +} + +status_t mkldnn_batch_normalization_backward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc, + float epsilon, unsigned flags) { + if (!one_of(prop_kind, backward, backward_data)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc, + epsilon, flags); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp new file mode 100644 index 0000000000..f61410b33c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp @@ -0,0 +1,240 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef BATCH_NORMALIZATION_PD_HPP +#define BATCH_NORMALIZATION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct batch_normalization_fwd_pd_t; + +struct batch_normalization_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::batch_normalization; + + batch_normalization_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , stat_md_(desc_.mean_desc) + , scaleshift_md_(desc_.data_scaleshift_desc) + , ws_md_() + {} + + const batch_normalization_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::batch_normalization_d: + *(const batch_normalization_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common batch_normalization aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return desc_.data_desc.ndims; } + + bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; } + bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; } + bool use_global_stats() const + { return desc_.flags & mkldnn_use_global_stats; } + bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; } + bool with_relu_post_op() const { + const auto &p = this->attr()->post_ops_; + return p.len_ == 1 && p.entry_[0].is_relu(true, true); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + bool is_bwd() const { return !this->is_fwd(); } + bool is_training() const + { return desc_.prop_kind == prop_kind::forward_training; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + batch_normalization_desc_t desc_; + const batch_normalization_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t stat_md_; + memory_desc_t scaleshift_md_; + + memory_desc_t ws_md_; + + void init_default_ws(size_t bits_per_element) { + const auto data_mdw = memory_desc_wrapper(data_md_); + + const dim_t data_nelems = data_mdw.nelems(true); + const dim_t bits_per_byte = 8; + const dims_t ws_sz = { (dim_t)utils::div_up( + data_nelems * bits_per_element, bits_per_byte) }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8, + format_tag::x); + } + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_fwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_fwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input; + if (arg == MKLDNN_ARG_DST) return arg_usage_t::output; + + if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) { + if (stats_is_src()) return arg_usage_t::input; + if (!stats_is_src() && is_training()) return arg_usage_t::output; + return arg_usage_t::unused; + } + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (!stats_is_src() && is_training() && (index == 1 || index == 2)) + return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const + { return stats_is_src() ? src_md(1) : dst_md(1); } + + virtual int n_inputs() const override + { return 1 + 2 * stats_is_src() + use_scaleshift(); } + virtual int n_outputs() const override + { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); } +}; + +struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_bwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_bwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN, + MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override + { return index == 0 ? &diff_scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const { return src_md(1); } + + virtual int n_inputs() const override + { return 4 + use_scaleshift() + fuse_bn_relu(); } + virtual int n_outputs() const override + { return 1 + (desc_.prop_kind == prop_kind::backward); } + +protected: + memory_desc_t diff_data_md_; + memory_desc_t diff_scaleshift_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp new file mode 100644 index 0000000000..3d43a0fbee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp @@ -0,0 +1,550 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TYPE_MAPPING_HPP +#define TYPE_MAPPING_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { + +// TODO: autogenerate this + +using dim_t = mkldnn_dim_t; +using dims_t = mkldnn_dims_t; +using stride_t = mkldnn_dim_t; +using strides_t = mkldnn_strides_t; + +using status_t = mkldnn_status_t; +namespace status { + const status_t success = mkldnn_success; + const status_t out_of_memory = mkldnn_out_of_memory; + const status_t try_again = mkldnn_try_again; + const status_t invalid_arguments = mkldnn_invalid_arguments; + const status_t not_ready = mkldnn_not_ready; + const status_t unimplemented = mkldnn_unimplemented; + const status_t iterator_ends = mkldnn_iterator_ends; + const status_t runtime_error = mkldnn_runtime_error; + const status_t not_required = mkldnn_not_required; +} + +using prop_kind_t = mkldnn_prop_kind_t; +namespace prop_kind { + const prop_kind_t undef = mkldnn_prop_kind_undef; + const prop_kind_t forward_training = mkldnn_forward_training; + const prop_kind_t forward_inference = mkldnn_forward_inference; + const prop_kind_t forward_scoring = mkldnn_forward_scoring; + const prop_kind_t forward = mkldnn_forward; + const prop_kind_t backward = mkldnn_backward; + const prop_kind_t backward_data = mkldnn_backward_data; + const prop_kind_t backward_weights = mkldnn_backward_weights; + const prop_kind_t backward_bias = mkldnn_backward_bias; +} + +using alg_kind_t = mkldnn_alg_kind_t; +namespace alg_kind { + const alg_kind_t undef = mkldnn_alg_kind_undef; + const alg_kind_t convolution_auto = mkldnn_convolution_auto; + const alg_kind_t convolution_direct = mkldnn_convolution_direct; + const alg_kind_t convolution_winograd = mkldnn_convolution_winograd; + const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct; + const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd; + const alg_kind_t eltwise_relu = mkldnn_eltwise_relu; + const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh; + const alg_kind_t eltwise_elu = mkldnn_eltwise_elu; + const alg_kind_t eltwise_square = mkldnn_eltwise_square; + const alg_kind_t eltwise_abs = mkldnn_eltwise_abs; + const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt; + const alg_kind_t eltwise_linear = mkldnn_eltwise_linear; + const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu; + const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu; + const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic; + const alg_kind_t pooling_max = mkldnn_pooling_max; + const alg_kind_t pooling_avg = mkldnn_pooling_avg; + const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding; + const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding; + const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels; + const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel; + const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn; + const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm; + const alg_kind_t vanilla_gru = mkldnn_vanilla_gru; + const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset; +} + +using data_type_t = mkldnn_data_type_t; +namespace data_type { + const data_type_t undef = mkldnn_data_type_undef; + const data_type_t f32 = mkldnn_f32; + const data_type_t s32 = mkldnn_s32; + const data_type_t s8 = mkldnn_s8; + const data_type_t u8 = mkldnn_u8; +} + +using scratchpad_mode_t = mkldnn_scratchpad_mode_t; +namespace scratchpad_mode { + const scratchpad_mode_t library = mkldnn_scratchpad_mode_library; + const scratchpad_mode_t user = mkldnn_scratchpad_mode_user; +} + +using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t; +namespace rnn_packed_format { + const rnn_packed_format_t undef = mkldnn_packed_format_undef; + const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p; + const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p; +} + +using format_kind_t = mkldnn_format_kind_t; +namespace format_kind { + const format_kind_t undef = mkldnn_format_kind_undef; + const format_kind_t any = mkldnn_format_kind_any; + const format_kind_t blocked = mkldnn_blocked; + const format_kind_t wino = mkldnn_format_kind_wino; + const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed; +} + +using format_tag_t = mkldnn_format_tag_t; +namespace format_tag { + const format_tag_t undef = mkldnn_format_tag_undef; + const format_tag_t any = mkldnn_format_tag_any; + const format_tag_t a = mkldnn_a; + const format_tag_t ab = mkldnn_ab; + const format_tag_t abc = mkldnn_abc; + const format_tag_t abcd = mkldnn_abcd; + const format_tag_t abcde = mkldnn_abcde; + const format_tag_t abcdef = mkldnn_abcdef; + const format_tag_t abdec = mkldnn_abdec; + const format_tag_t acb = mkldnn_acb; + const format_tag_t acbde = mkldnn_acbde; + const format_tag_t acdb = mkldnn_acdb; + const format_tag_t acdeb = mkldnn_acdeb; + const format_tag_t ba = mkldnn_ba; + const format_tag_t bac = mkldnn_bac; + const format_tag_t bacd = mkldnn_bacd; + const format_tag_t bcda = mkldnn_bcda; + const format_tag_t cba = mkldnn_cba; + const format_tag_t cdba = mkldnn_cdba; + const format_tag_t cdeba = mkldnn_cdeba; + const format_tag_t decab = mkldnn_decab; + const format_tag_t Abc16a = mkldnn_Abc16a; + const format_tag_t ABc16a16b = mkldnn_ABc16a16b; + const format_tag_t aBc16b = mkldnn_aBc16b; + const format_tag_t ABc16b16a = mkldnn_ABc16b16a; + const format_tag_t Abc4a = mkldnn_Abc4a; + const format_tag_t aBc4b = mkldnn_aBc4b; + const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b; + const format_tag_t ABc4b4a = mkldnn_ABc4b4a; + const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a; + const format_tag_t ABc8a8b = mkldnn_ABc8a8b; + const format_tag_t aBc8b = mkldnn_aBc8b; + const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b; + const format_tag_t ABc8b8a = mkldnn_ABc8b8a; + const format_tag_t Abcd16a = mkldnn_Abcd16a; + const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b; + const format_tag_t aBcd16b = mkldnn_aBcd16b; + const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a; + const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c; + const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b; + const format_tag_t Abcd4a = mkldnn_Abcd4a; + const format_tag_t aBcd4b = mkldnn_aBcd4b; + const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b; + const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a; + const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c; + const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b; + const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a; + const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b; + const format_tag_t aBcd8b = mkldnn_aBcd8b; + const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b; + const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b; + const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a; + const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c; + const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c; + const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b; + const format_tag_t Abcde16a = mkldnn_Abcde16a; + const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b; + const format_tag_t aBcde16b = mkldnn_aBcde16b; + const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a; + const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c; + const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b; + const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c; + const format_tag_t Abcde4a = mkldnn_Abcde4a; + const format_tag_t aBcde4b = mkldnn_aBcde4b; + const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a; + const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c; + const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c; + const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b; + const format_tag_t Abcde8a = mkldnn_Abcde8a; + const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b; + const format_tag_t aBcde8b = mkldnn_aBcde8b; + const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b; + const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b; + const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a; + const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c; + const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c; + const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b; + const format_tag_t aBcdef16b = mkldnn_aBcdef16b; + const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c; + const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b; + const format_tag_t aBcdef4b = mkldnn_aBcdef4b; + const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b; + const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c; + const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c; + const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b; + const format_tag_t aBdc16b = mkldnn_aBdc16b; + const format_tag_t aBdc4b = mkldnn_aBdc4b; + const format_tag_t aBdc8b = mkldnn_aBdc8b; + const format_tag_t aBdec16b = mkldnn_aBdec16b; + const format_tag_t aBdec4b = mkldnn_aBdec4b; + const format_tag_t aBdec8b = mkldnn_aBdec8b; + const format_tag_t aBdefc16b = mkldnn_aBdefc16b; + const format_tag_t aBdefc4b = mkldnn_aBdefc4b; + const format_tag_t aBdefc8b = mkldnn_aBdefc8b; + const format_tag_t Acb16a = mkldnn_Acb16a; + const format_tag_t Acb4a = mkldnn_Acb4a; + const format_tag_t Acb8a = mkldnn_Acb8a; + const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c; + const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c; + const format_tag_t Acdb16a = mkldnn_Acdb16a; + const format_tag_t Acdb4a = mkldnn_Acdb4a; + const format_tag_t Acdb8a = mkldnn_Acdb8a; + const format_tag_t Acdeb16a = mkldnn_Acdeb16a; + const format_tag_t Acdeb4a = mkldnn_Acdeb4a; + const format_tag_t Acdeb8a = mkldnn_Acdeb8a; + const format_tag_t BAc16a16b = mkldnn_BAc16a16b; + const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b; + const format_tag_t last = mkldnn_format_tag_last; + + const format_tag_t x = mkldnn_x; + const format_tag_t nc = mkldnn_nc; + const format_tag_t cn = mkldnn_cn; + const format_tag_t ncw = mkldnn_ncw; + const format_tag_t nwc = mkldnn_nwc; + const format_tag_t nchw = mkldnn_nchw; + const format_tag_t nhwc = mkldnn_nhwc; + const format_tag_t chwn = mkldnn_chwn; + const format_tag_t ncdhw = mkldnn_ncdhw; + const format_tag_t ndhwc = mkldnn_ndhwc; + const format_tag_t oi = mkldnn_oi; + const format_tag_t io = mkldnn_io; + const format_tag_t oiw = mkldnn_oiw; + const format_tag_t wio = mkldnn_wio; + const format_tag_t oihw = mkldnn_oihw; + const format_tag_t hwio = mkldnn_hwio; + const format_tag_t ihwo = mkldnn_ihwo; + const format_tag_t iohw = mkldnn_iohw; + const format_tag_t oidhw = mkldnn_oidhw; + const format_tag_t dhwio = mkldnn_dhwio; + const format_tag_t goiw = mkldnn_goiw; + const format_tag_t goihw = mkldnn_goihw; + const format_tag_t hwigo = mkldnn_hwigo; + const format_tag_t giohw = mkldnn_giohw; + const format_tag_t goidhw = mkldnn_goidhw; + const format_tag_t tnc = mkldnn_tnc; + const format_tag_t ntc = mkldnn_ntc; + const format_tag_t ldsnc = mkldnn_ldsnc; + const format_tag_t ldigo = mkldnn_ldigo; + const format_tag_t ldgoi = mkldnn_ldgoi; + const format_tag_t ldgo = mkldnn_ldgo; + const format_tag_t nCdhw16c = mkldnn_nCdhw16c; + const format_tag_t nCdhw4c = mkldnn_nCdhw4c; + const format_tag_t nCdhw8c = mkldnn_nCdhw8c; + const format_tag_t nChw16c = mkldnn_nChw16c; + const format_tag_t nChw4c = mkldnn_nChw4c; + const format_tag_t nChw8c = mkldnn_nChw8c; + const format_tag_t nCw16c = mkldnn_nCw16c; + const format_tag_t nCw4c = mkldnn_nCw4c; + const format_tag_t nCw8c = mkldnn_nCw8c; + const format_tag_t IOw16o16i = mkldnn_IOw16o16i; + const format_tag_t OIw16i16o = mkldnn_OIw16i16o; + const format_tag_t OIw16o16i = mkldnn_OIw16o16i; + const format_tag_t Oiw16o = mkldnn_Oiw16o; + const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i; + const format_tag_t OIw4i4o = mkldnn_OIw4i4o; + const format_tag_t Oiw4o = mkldnn_Oiw4o; + const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i; + const format_tag_t OIw8i8o = mkldnn_OIw8i8o; + const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o; + const format_tag_t OIw8o8i = mkldnn_OIw8o8i; + const format_tag_t Owi16o = mkldnn_Owi16o; + const format_tag_t Owi4o = mkldnn_Owi4o; + const format_tag_t Owi8o = mkldnn_Owi8o; + const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i; + const format_tag_t Ohwi16o = mkldnn_Ohwi16o; + const format_tag_t Ohwi4o = mkldnn_Ohwi4o; + const format_tag_t Ohwi8o = mkldnn_Ohwi8o; + const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o; + const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i; + const format_tag_t Oihw16o = mkldnn_Oihw16o; + const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i; + const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o; + const format_tag_t Oihw4o = mkldnn_Oihw4o; + const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i; + const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o; + const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o; + const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i; + const format_tag_t Odhwi16o = mkldnn_Odhwi16o; + const format_tag_t Odhwi4o = mkldnn_Odhwi4o; + const format_tag_t Odhwi8o = mkldnn_Odhwi8o; + const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o; + const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i; + const format_tag_t Oidhw16o = mkldnn_Oidhw16o; + const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o; + const format_tag_t Oidhw4o = mkldnn_Oidhw4o; + const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i; + const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o; + const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i; + const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i; + const format_tag_t Goiw16g = mkldnn_Goiw16g; + const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o; + const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i; + const format_tag_t gOiw16o = mkldnn_gOiw16o; + const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i; + const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o; + const format_tag_t gOiw4o = mkldnn_gOiw4o; + const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i; + const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o; + const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o; + const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i; + const format_tag_t gOwi16o = mkldnn_gOwi16o; + const format_tag_t gOwi4o = mkldnn_gOwi4o; + const format_tag_t gOwi8o = mkldnn_gOwi8o; + const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i; + const format_tag_t gOhwi16o = mkldnn_gOhwi16o; + const format_tag_t gOhwi4o = mkldnn_gOhwi4o; + const format_tag_t gOhwi8o = mkldnn_gOhwi8o; + const format_tag_t Goihw16g = mkldnn_Goihw16g; + const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o; + const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i; + const format_tag_t gOihw16o = mkldnn_gOihw16o; + const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i; + const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i; + const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o; + const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i; + const format_tag_t gOihw4o = mkldnn_gOihw4o; + const format_tag_t Goihw8g = mkldnn_Goihw8g; + const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i; + const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o; + const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o; + const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i; + const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o; + const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o; + const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o; + const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o; + const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i; + const format_tag_t gOidhw16o = mkldnn_gOidhw16o; + const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o; + const format_tag_t gOidhw4o = mkldnn_gOidhw4o; + const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i; + const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o; + const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i; +} + +using memory_extra_flags_t = mkldnn_memory_extra_flags_t; +namespace memory_extra_flags { + const memory_extra_flags_t none = mkldnn_memory_extra_flag_none; + const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8; + const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust; +} + +using padding_kind_t = mkldnn_padding_kind_t; +namespace padding_kind { + const padding_kind_t padding_zero = mkldnn_padding_zero; +} + +using engine_kind_t = mkldnn_engine_kind_t; +namespace engine_kind { + const engine_kind_t any_engine = mkldnn_any_engine; + const engine_kind_t cpu = mkldnn_cpu; +} + +using primitive_kind_t = mkldnn_primitive_kind_t; +namespace primitive_kind { + const primitive_kind_t undefined = mkldnn_undefined_primitive; + const primitive_kind_t reorder = mkldnn_reorder; + const primitive_kind_t concat = mkldnn_concat; + const primitive_kind_t sum = mkldnn_sum; + const primitive_kind_t convolution = mkldnn_convolution; + const primitive_kind_t deconvolution = mkldnn_deconvolution; + const primitive_kind_t shuffle = mkldnn_shuffle; + const primitive_kind_t eltwise = mkldnn_eltwise; + const primitive_kind_t softmax = mkldnn_softmax; + const primitive_kind_t pooling = mkldnn_pooling; + const primitive_kind_t lrn = mkldnn_lrn; + const primitive_kind_t batch_normalization = mkldnn_batch_normalization; + const primitive_kind_t inner_product = mkldnn_inner_product; + const primitive_kind_t rnn = mkldnn_rnn; +} + +using query_t = mkldnn_query_t; +namespace query { + const query_t undef = mkldnn_query_undef; + + const query_t engine = mkldnn_query_engine; + const query_t primitive_kind = mkldnn_query_primitive_kind; + + const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32; + const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32; + + const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64; + const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64; + + const query_t scratchpad_engine = mkldnn_query_scratchpad_engine; + + const query_t impl_info_str = mkldnn_query_impl_info_str; + + const query_t some_d = mkldnn_query_some_d; + const query_t op_d = mkldnn_query_op_d; + const query_t convolution_d = mkldnn_query_convolution_d; + const query_t deconvolution_d = mkldnn_query_deconvolution_d; + const query_t shuffle_d = mkldnn_query_shuffle_d; + const query_t eltwise_d = mkldnn_query_eltwise_d; + const query_t softmax_d = mkldnn_query_softmax_d; + const query_t pooling_d = mkldnn_query_pooling_d; + const query_t lrn_d = mkldnn_query_lrn_d; + const query_t batch_normalization_d = mkldnn_query_batch_normalization_d; + const query_t inner_product_d = mkldnn_query_inner_product_d; + const query_t rnn_d = mkldnn_query_rnn_d; + + const query_t some_md = mkldnn_query_some_md; + const query_t src_md = mkldnn_query_src_md; + const query_t diff_src_md = mkldnn_query_diff_src_md; + const query_t weights_md = mkldnn_query_weights_md; + const query_t diff_weights_md = mkldnn_query_diff_weights_md; + const query_t dst_md = mkldnn_query_dst_md; + const query_t diff_dst_md = mkldnn_query_diff_dst_md; + + const query_t workspace_md = mkldnn_query_workspace_md; + const query_t scratchpad_md = mkldnn_query_scratchpad_md; +} + +using blocking_desc_t = mkldnn_blocking_desc_t; +using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t; +using wino_desc_t = mkldnn_wino_desc_t; +using memory_extra_desc_t = mkldnn_memory_extra_desc_t; +using memory_desc_t = mkldnn_memory_desc_t; +using convolution_desc_t = mkldnn_convolution_desc_t; +using deconvolution_desc_t = mkldnn_deconvolution_desc_t; +using shuffle_desc_t = mkldnn_shuffle_desc_t; +using pooling_desc_t = mkldnn_pooling_desc_t; +using eltwise_desc_t = mkldnn_eltwise_desc_t; +using softmax_desc_t = mkldnn_softmax_desc_t; +using lrn_desc_t = mkldnn_lrn_desc_t; +using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t; +using inner_product_desc_t = mkldnn_inner_product_desc_t; + +using rnn_direction_t = mkldnn_rnn_direction_t; +using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t; +using rnn_desc_t = mkldnn_rnn_desc_t; + +/* C op_desc_t, which eventually are just (void*) */ +using c_op_desc_t = mkldnn_op_desc_t; +using const_c_op_desc_t = const_mkldnn_op_desc_t; + +struct op_desc_t { + union { + primitive_kind_t kind; + convolution_desc_t convolution; + deconvolution_desc_t deconvolution; + shuffle_desc_t shuffle; + pooling_desc_t pooling; + eltwise_desc_t eltwise; + softmax_desc_t softmax; + lrn_desc_t lrn; + batch_normalization_desc_t batch_normalization; + inner_product_desc_t inner_product; + rnn_desc_t rnn; + }; + + op_desc_t(const primitive_kind_t &_): kind(_) {} + +# define DECL_CTOR_AND_CONVERTERS(c_type, name) \ + op_desc_t(const c_type &_): name(_) {} \ + static op_desc_t *convert_from_c(c_type *_) \ + { return reinterpret_cast<op_desc_t*>(_); } \ + static const op_desc_t *convert_from_c(const c_type *_) \ + { return reinterpret_cast<const op_desc_t*>(_); } + + DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution); + DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle); + DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling); + DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise); + DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax); + DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn); + DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization); + DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product); + DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn); + +# undef DECL_CTOR_AND_CONVERTERS +}; + +using engine_t = mkldnn_engine; +using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator; +using primitive_desc_t = mkldnn_primitive_desc; +using primitive_attr_t = mkldnn_primitive_attr; +using post_ops_t = mkldnn_post_ops; +using memory_t = mkldnn_memory; +using primitive_t = mkldnn_primitive; + +using primitive_arg_index_t = int; + +using stream_flags_t = mkldnn_stream_flags_t; +namespace stream_flags { + const stream_flags_t default_flags = mkldnn_stream_default_flags; +} +using stream_t = mkldnn_stream; + +/* forward declaration of the internal primitive_desc types */ +struct batch_normalization_bwd_pd_t; +struct batch_normalization_fwd_pd_t; +struct batch_normalization_pd_t; +struct concat_pd_t; +struct convolution_bwd_data_pd_t; +struct convolution_bwd_weights_pd_t; +struct convolution_fwd_pd_t; +struct convolution_pd_t; +struct deconvolution_bwd_data_pd_t; +struct deconvolution_bwd_weights_pd_t; +struct deconvolution_fwd_pd_t; +struct deconvolution_pd_t; +struct eltwise_bwd_pd_t; +struct eltwise_fwd_pd_t; +struct eltwise_pd_t; +struct inner_product_bwd_data_pd_t; +struct inner_product_bwd_weights_pd_t; +struct inner_product_fwd_pd_t; +struct inner_product_pd_t; +struct lrn_bwd_pd_t; +struct lrn_fwd_pd_t; +struct lrn_pd_t; +struct pooling_bwd_pd_t; +struct pooling_fwd_pd_t; +struct pooling_pd_t; +struct reorder_pd_t; +struct rnn_bwd_pd_t; +struct rnn_fwd_pd_t; +struct rnn_pd_t; +struct shuffle_pd_t; +struct softmax_bwd_pd_t; +struct softmax_fwd_pd_t; +struct softmax_pd_t; +struct sum_pd_t; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp new file mode 100644 index 0000000000..ed4c35c6e9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "concat_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds, + const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(concat_pd, src_mds) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + int concat_dim_sz = dims[concat_dim]; + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (d == concat_dim) continue; + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + concat_dim_sz += src_mds[i].dims[concat_dim]; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != + (d == concat_dim ? concat_dim_sz : dims[d])) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.dims[concat_dim] = concat_dim_sz; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd); + + for (auto c = engine->get_concat_implementation_list(); *c; ++c) { + if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds) + == success) { + (*c_pd)->init_info(); + (*c_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp new file mode 100644 index 0000000000..29311927e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp @@ -0,0 +1,211 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONCAT_PD_HPP +#define CONCAT_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct concat_pd_t: public primitive_desc_t { + concat_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::concat) + , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md) + { + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + concat_pd_t(const concat_pd_t &rhs) = default; + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + int concat_dim() const { return concat_dim_; } + + const memory_desc_t *src_image_md(int index = 0) const + { return index < n_inputs() ? &src_image_mds_[index] : nullptr; } + +protected: + int n_, concat_dim_; + memory_desc_t dst_md_; + nstl::vector<memory_desc_t> src_mds_; + + /* contains images of srcs in the dst memory (if possible) + * Lives here to simplify some implementations. An implementation might + * use this auxiliary array iff init() returned success */ + nstl::vector<memory_desc_t> src_image_mds_; + +protected: + /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */ + status_t init() { + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + if (!i_d.is_blocking_desc() || i_d.is_additional_buffer()) + return status::unimplemented; + } + + const int ndims = dst_md_.ndims; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) return status; + src_image_mds_.push_back(src_img_d); + current_concat_dim_offset += dim; + } + + return status::success; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + const int ndims = dst_md_.ndims; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain or it is not possible to create a + * blocked format for the output, pick the format of the plain input + * - If this fails as well, use plain layout (abcd...) + */ + status_t status = status::unimplemented; + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && !src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + if (status == status::success) break; + } + } + + if (status == status::success) { + /* check if we can create a sub-memory for the dst */ + bool desired_format_ok = true; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) { + desired_format_ok = false; + break; + } + current_concat_dim_offset += dim; + } + + if (!desired_format_ok) + status = status::unimplemented; + } + + /* if no success so far, try using the format of the first plain input */ + if (status != status::success) { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + memory_desc_wrapper(src_mds_[0]).blocking_desc()); + if (status == status::success) return status; + } + } + } + + /* the last line of defense: use plain abcd... format */ + if (status != status::success) + status = memory_desc_init_by_strides(dst_md_, nullptr); + + return status; + } +}; + +#define DECLARE_CONCAT_PD_t(impl_name, ...) \ + static status_t create(concat_pd_t **concat_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, int concat_dim, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_CONCAT_PD_T(impl_name, ...) \ + DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp new file mode 100644 index 0000000000..0c5c02bcd1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto cd = convolution_desc_t(); + cd.primitive_kind = primitive_kind::convolution; + cd.prop_kind = prop_kind; + cd.alg_kind = alg_kind; + + cd.diff_src_desc = cd.src_desc = zero_md(); + cd.diff_dst_desc = cd.dst_desc = zero_md(); + cd.diff_weights_desc = cd.weights_desc = zero_md(); + cd.diff_bias_desc = cd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc; + (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) = + *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(cd.strides, strides, sp_dims); + utils::array_copy(cd.padding[0], padding_l, sp_dims); + utils::array_copy(cd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(cd.dilates, dilates, sp_dims); + else + utils::array_set(cd.dilates, 0, sp_dims); + + cd.padding_kind = padding_kind; + cd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + const int bias_dim = prop_kind == backward_data + ? src_desc->dims[1] + : dst_desc->dims[1]; + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == bias_dim : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) + { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = cd.dilates[i - 2]; + int pad_l = padding_l[i - 2]; + int pad_r = padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + if (str < 1) return invalid_arguments; + consistency = consistency + && dil >= 0 + && pad_l >= 0 + && pad_r + str > 0 + && (src - ker_range + pad_l + pad_r) / str + 1 == dst; + } + if (!consistency) return invalid_arguments; + + *conv_desc = cd; + return success; +} +} +} + +status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_forward_desc_init( + convolution_desc_t *conv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + nullptr, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + dilates, padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp new file mode 100644 index 0000000000..9604e0acf5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "convolution_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? &desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); } +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); } +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); } +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp new file mode 100644 index 0000000000..b10c36db49 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONVOLUTION_PD_HPP +#define CONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind); + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc); + +struct convolution_fwd_pd_t; + +struct convolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::convolution; + + convolution_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const convolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits<base_pkind>::query_d: + *(const convolution_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common conv aux functions */ + + dim_t MB() const { return _src_md()->dims[0]; } + + dim_t IC() const { return _src_md()->dims[1]; } + dim_t OC() const { return _dst_md()->dims[1]; } + dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; } + + dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; } + dim_t IW() const { return _src_md()->dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; } + dim_t OW() const { return _dst_md()->dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; } + dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; } + dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return _src_md()->ndims; } + + bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); } + bool with_groups() const { return _wei_md()->ndims == ndims() + 1; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*_src_md()); + const auto d_d = memory_desc_wrapper(*_dst_md()); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + convolution_desc_t desc_; + const convolution_fwd_pd_t *hint_fwd_pd_; + + bool set_default_formats_common_template( + memory_desc_t &src_md, format_tag_t src_tag, + memory_desc_t &wei_md, format_tag_t wei_tag, + memory_desc_t &dst_md, format_tag_t dst_tag, + memory_desc_t &bia_md) { + using namespace format_tag; + +# define IS_OK(f) \ + do { if ((f) != status::success) return false; } while(0) + if (src_md.format_kind == format_kind::any + && !utils::one_of(src_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(src_md, src_tag)); + if (dst_md.format_kind == format_kind::any + && !utils::one_of(dst_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); + if (wei_md.format_kind == format_kind::any + && !utils::one_of(wei_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); + if (with_bias() && bia_md.format_kind == format_kind::any) + IS_OK(memory_desc_init_by_tag(bia_md, x)); +# undef IS_OK + + return true; + } + + bool set_default_alg_kind(alg_kind_t alg_kind) { + assert(utils::one_of(alg_kind, alg_kind::convolution_direct, + alg_kind::convolution_winograd)); + if (desc_.alg_kind == alg_kind::convolution_auto) + desc_.alg_kind = alg_kind; + return desc_.alg_kind == alg_kind; + } + + bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, + data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { + bool ok = true + && (src_dt == data_type::undef || _src_md()->data_type == src_dt) + && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt) + && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt) + && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt); + if (with_bias() && bia_dt != data_type::undef) + ok = ok && _bia_md()->data_type == bia_dt; + return ok; + } + +private: + const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); } + const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); } + const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); } + const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); } +}; + +struct convolution_fwd_pd_t: public convolution_pd_t { + typedef convolution_fwd_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_fwd_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t wei_tag, format_tag_t dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); + } +}; + +struct convolution_bwd_data_pd_t: public convolution_pd_t { + typedef convolution_bwd_data_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_data_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + + virtual bool support_bias() const { return false; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t diff_src_tag, + format_tag_t wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(diff_src_md_, diff_src_tag, + weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_); + } +}; + +struct convolution_bwd_weights_pd_t: public convolution_pd_t { + typedef convolution_bwd_weights_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_weights_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, + diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp new file mode 100644 index 0000000000..98063c1c37 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp @@ -0,0 +1,188 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t deconv_desc_init(deconvolution_desc_t *deconv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, deconvolution_direct, deconvolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) + return invalid_arguments; + + if (padding_r == nullptr) + padding_r = padding_l; + + auto dd = deconvolution_desc_t(); + dd.primitive_kind = primitive_kind::deconvolution; + dd.prop_kind = prop_kind; + dd.alg_kind = alg_kind; + + dd.diff_src_desc = dd.src_desc = zero_md(); + dd.diff_dst_desc = dd.dst_desc = zero_md(); + dd.diff_weights_desc = dd.weights_desc = zero_md(); + dd.diff_bias_desc = dd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias + = bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc; + (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc) + = *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc) + = *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(dd.strides, strides, sp_dims); + utils::array_copy(dd.padding[0], padding_l, sp_dims); + utils::array_copy(dd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(dd.dilates, dilates, sp_dims); + else + utils::array_set(dd.dilates, 0, sp_dims); + + dd.padding_kind = padding_kind; + dd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + bool consistency = true + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = dd.dilates[i - 2]; + int pad = padding_l[i - 2] + padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + consistency + = consistency && (dst - ker_range + pad) / str + 1 == src; + } + if (!consistency) + return invalid_arguments; + + *deconv_desc = dd; + return success; +} +} + +status_t mkldnn_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp new file mode 100644 index 0000000000..539e44bd9b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp @@ -0,0 +1,293 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DECONVOLUTION_PD_HPP +#define DECONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct deconvolution_fwd_pd_t; + +struct deconvolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::deconvolution; + + deconvolution_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const deconvolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits<base_pkind>::query_d: + *(const deconvolution_desc_t **)result = desc(); + break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */ + + dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; } + + dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; } + dim_t G() const + { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; } + + dim_t ID() const { + return ndims() >= 5 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1]; + } + + dim_t OD() const { + return ndims() >= 5 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1]; + } + + dim_t KD() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 5 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1; + } + dim_t KH() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 4 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1; + } + dim_t KW() const { + const int w_ndims = ndims() + with_groups(); + return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1]; + } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + bool with_bias() const { + return + !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero(); + } + + bool with_groups() const + { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; } + + int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + deconvolution_desc_t desc_; + const deconvolution_fwd_pd_t *hint_fwd_pd_; +}; + +struct deconvolution_fwd_pd_t: public deconvolution_pd_t { + typedef deconvolution_fwd_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_fwd_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; +}; + +struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_data_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_data_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; +}; + +struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_weights_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_weights_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp new file mode 100644 index 0000000000..f1708fca52 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float alpha, float beta) { + bool args_ok = true + && !any_null(eltwise_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data) + && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ed = eltwise_desc_t(); + ed.primitive_kind = primitive_kind::eltwise; + ed.prop_kind = prop_kind; + ed.alg_kind = alg_kind; + + ed.data_desc = *data_desc; + ed.diff_data_desc = + (ed.prop_kind == backward_data) ? *diff_data_desc : zero_md(); + + ed.alpha = alpha; + ed.beta = beta; + + bool consistency = true + && IMPLICATION(ed.prop_kind == backward_data, + array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims, + ed.diff_data_desc.ndims)); + if (!consistency) return invalid_arguments; + + *eltwise_desc = ed; + return success; +} +} + +status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, float alpha, float beta) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc, + nullptr, alpha, beta); +} + +status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_data_desc, + const memory_desc_t *data_desc, float alpha, float beta) { + return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc, + diff_data_desc, alpha, beta); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp new file mode 100644 index 0000000000..9fd260fcee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ELTWISE_PD_HPP +#define ELTWISE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct eltwise_fwd_pd_t; + +struct eltwise_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::eltwise; + + eltwise_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const eltwise_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::eltwise_d: + *(const eltwise_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common eltwise aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + eltwise_desc_t desc_; + const eltwise_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct eltwise_fwd_pd_t: public eltwise_pd_t { + typedef eltwise_fwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const + { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); } +}; + +struct eltwise_bwd_pd_t: public eltwise_pd_t { + typedef eltwise_bwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_bwd_pd_t(engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const { return true; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp new file mode 100644 index 0000000000..3b3e25456d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp @@ -0,0 +1,75 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include "engine.hpp" +#include "nstl.hpp" + +#include "c_types_map.hpp" +#include "../cpu/cpu_engine.hpp" + +namespace mkldnn { +namespace impl { + +engine_factory_t *engine_factories[] = { + &cpu::engine_factory, + nullptr, +}; + +static inline engine_factory_t *get_engine_factory(engine_kind_t kind) { + for (engine_factory_t **ef = engine_factories; *ef; ef++) + if ((*ef)->kind() == kind) + return *ef; + return nullptr; +} + +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +size_t mkldnn_engine_get_count(engine_kind_t kind) { + engine_factory_t *ef = get_engine_factory(kind); + return ef != nullptr ? ef->count() : 0; +} + +status_t mkldnn_engine_create(engine_t **engine, + engine_kind_t kind, size_t index) { + if (engine == nullptr) + return invalid_arguments; + + engine_factory_t *ef = get_engine_factory(kind); + if (ef == nullptr || index >= ef->count()) + return invalid_arguments; + + return ef->engine_create(engine, index); +} + +status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) { + if (engine == nullptr) + return invalid_arguments; + *kind = engine->kind(); + return success; +} + +status_t mkldnn_engine_destroy(engine_t *engine) { + /* TODO: engine->dec_ref_count(); */ + delete engine; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp new file mode 100644 index 0000000000..8ac8a29de5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ENGINE_HPP +#define ENGINE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive.hpp" +#include "utils.hpp" + +/** \brief An abstraction of an execution unit with shared resources + * + * Responsibilities: + * - Provide engine specific memory allocation + * - Provide engine specific primitive_desc_t creators + */ +struct mkldnn_engine: public mkldnn::impl::c_compatible { + mkldnn_engine(mkldnn::impl::engine_kind_t kind) + : kind_(kind) + {} + virtual ~mkldnn_engine() {} + + /** get kind of the current engine */ + virtual mkldnn::impl::engine_kind_t kind() const { return kind_; } + + /** allocate memory */ + virtual mkldnn::impl::status_t memory_create( + mkldnn::impl::memory_t **memory, + const mkldnn::impl::memory_desc_t *md, + void *handle) = 0; + + /** implementation section (typedefs) */ + + // TODO: remove engine? + typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)( + mkldnn::impl::reorder_pd_t **reorder_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *src_engine, + const mkldnn::impl::memory_desc_t *src_md, + mkldnn::impl::engine_t *dst_engine, + const mkldnn::impl::memory_desc_t *dst_md); + + typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)( + mkldnn::impl::concat_pd_t **concat_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, int concat_dim, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)( + mkldnn::impl::sum_pd_t **sum_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, const float *scales, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*primitive_desc_create_f)( + mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *); + + /* implementation section */ + + /** return the list of reorder implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const = 0; + + /** return the list of concat implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const = 0; + + /** return the list of sum implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const = 0; + + /** return the list of implementations. engine guarantees to return a + * NULL-terminated list */ + virtual const primitive_desc_create_f* get_implementation_list() const = 0; + +protected: + mkldnn::impl::engine_kind_t kind_; +}; + +namespace mkldnn { +namespace impl { + +struct engine_factory_t: public c_compatible { + virtual size_t count() const = 0; + virtual engine_kind_t kind() const = 0; + virtual status_t engine_create(engine_t **engine, size_t index) const = 0; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp new file mode 100644 index 0000000000..5a9f58cb1e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) { + bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc); + if (!args_ok) return invalid_arguments; + + auto id = inner_product_desc_t(); + id.primitive_kind = primitive_kind::inner_product; + id.prop_kind = prop_kind; + + id.diff_src_desc = id.src_desc = zero_md(); + id.diff_dst_desc = id.dst_desc = zero_md(); + id.diff_weights_desc = id.weights_desc = zero_md(); + id.diff_bias_desc = id.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + + (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc; + (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) = + *bias_desc; + + id.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && one_of(src_desc->ndims, 2, 3, 4, 5) + && dst_desc->ndims == 2 + && weights_desc->ndims == src_desc->ndims + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && array_cmp(&src_desc->dims[1], &weights_desc->dims[1], + src_desc->ndims - 1) + && dst_desc->dims[1] == weights_desc->dims[0]; + if (!consistency) return invalid_arguments; + + *ip_desc = id; + return success; +} +} + +status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc, + prop_kind_t prop_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc, + dst_desc); +} + +status_t mkldnn_inner_product_backward_data_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc) +{ + return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc, + nullptr, diff_dst_desc); +} + +status_t mkldnn_inner_product_backward_weights_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *src_desc, + const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc) { + return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc, + diff_bias_desc, diff_dst_desc); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp new file mode 100644 index 0000000000..091cf0f5d6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "inner_product_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? &desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp new file mode 100644 index 0000000000..c426de632c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp @@ -0,0 +1,321 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef INNER_PRODUCT_PD_HPP +#define INNER_PRODUCT_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc); + +struct inner_product_fwd_pd_t; + +struct inner_product_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::inner_product; + + inner_product_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const inner_product_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::inner_product_d: + *(const inner_product_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common inner_product aux functions */ + + dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; } + dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; } + + dim_t ID() const { + return ndims() >= 5 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return ndims() >= 3 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t OD() const { + return ndims() >= 5 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return ndims() >= 3 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t KD() const { + return ndims() >= 5 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t KH() const { + return ndims() >= 4 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t KW() const { + return ndims() >= 3 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t IC_total() const { + return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1], + ndims() - 1); + } + + dim_t IC_total_padded() const { + auto src_d = desc()->prop_kind == prop_kind::backward_data + ? memory_desc_wrapper(diff_src_md()) + : memory_desc_wrapper(src_md()); + assert(src_d.is_blocking_desc()); + if (!src_d.is_blocking_desc()) return -1; + return utils::array_product(src_d.padded_dims() + 1, ndims() - 1); + } + + int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; } + + bool with_bias() const + { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + inner_product_desc_t desc_; + const inner_product_fwd_pd_t *hint_fwd_pd_; + + status_t template_set_default_params(memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t *bias_md) { + using namespace format_tag; + if (src_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, + utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw))); + } + if (dst_md.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md, nc)); + if (weights_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, + utils::pick(ndims() - 2, oi, oiw, oihw, oidhw))); + } + if (bias_md && bias_md->format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(*bias_md, x)); + return status::success; + } +}; + +struct inner_product_fwd_pd_t: public inner_product_pd_t { + typedef inner_product_fwd_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_fwd_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, weights_md_, dst_md_, + &bias_md_); + } +}; + +struct inner_product_bwd_data_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_data_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_data_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(diff_src_md_, weights_md_, + diff_dst_md_, nullptr); + } +}; + +struct inner_product_bwd_weights_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_weights_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_weights_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, diff_weights_md_, + diff_dst_md_, &diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp new file mode 100644 index 0000000000..fcf18b556f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t lrn_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc, + dim_t local_size, float alpha, float beta, float k) { + bool args_ok = true + && !any_null(lrn_desc, data_desc) + && one_of(alg_kind, lrn_within_channel, lrn_across_channels) + && one_of(prop_kind, forward_training, forward_inference, backward_data) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ld = lrn_desc_t(); + ld.primitive_kind = primitive_kind::lrn; + ld.prop_kind = prop_kind; + ld.alg_kind = alg_kind; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + ld.data_desc = *data_desc; + if (!is_fwd) + ld.diff_data_desc = *diff_data_desc; + else + ld.diff_data_desc = zero_md(); + ld.local_size = local_size; + ld.lrn_alpha = alpha; + ld.lrn_beta = beta; + ld.lrn_k = k; + + bool consistency = true + && ld.data_desc.ndims == 4; + if (ld.prop_kind == backward_data) + consistency = consistency + && ld.diff_data_desc.ndims == 4 + && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4); + if (!consistency) return invalid_arguments; + + *lrn_desc = ld; + return success; +} +} + +status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, dim_t local_size, float alpha, + float beta, float k) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr, + local_size, alpha, beta, k); +} + +status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, dim_t local_size, float alpha, + float beta, float k) { + return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc, + diff_data_desc, local_size, alpha, beta, k); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp new file mode 100644 index 0000000000..90886e9656 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef LRN_PD_HPP +#define LRN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct lrn_fwd_pd_t; + +struct lrn_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::lrn; + + lrn_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , ws_md_() + {} + + const lrn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::lrn_d: + *(const lrn_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common lrn aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + lrn_desc_t desc_; + const lrn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t ws_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct lrn_fwd_pd_t: public lrn_pd_t { + typedef lrn_fwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_fwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct lrn_bwd_pd_t: public lrn_pd_t { + typedef lrn_bwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_bwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp new file mode 100644 index 0000000000..3fddc0bd45 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MATH_UTILS_HPP +#define MATH_UTILS_HPP + +#include <stdint.h> +#include <math.h> + +#include "utils.hpp" +#include "nstl.hpp" +#include "mkldnn_traits.hpp" + +#if defined(MKLDNN_X86_64) +#include "immintrin.h" +#endif + +namespace mkldnn { +namespace impl { +namespace math { + +/** rounds @p f to an integer according to the mxcsr register */ +inline int mxcsr_round(float f) { +#if defined(MKLDNN_X86_64) + return _mm_cvtss_si32(_mm_load_ss(&f)); +#else + return (int)nearbyintf(f); // optimism +#endif +} + +template <typename data_t, typename acc_t> +inline typename utils::enable_if<!nstl::is_integral<data_t>::value, + typename utils::remove_reference<data_t>::type>::type +saturate(const acc_t &x) { + return (typename utils::remove_reference<data_t>::type)x; +} + +template <typename data_t, typename acc_t> +inline typename utils::enable_if<nstl::is_integral<data_t>::value, + typename utils::remove_reference<data_t>::type>::type +saturate(const acc_t &x) { + acc_t v = x; + if (v < (acc_t)nstl::numeric_limits<data_t>::lowest()) + v = (acc_t)nstl::numeric_limits<data_t>::lowest(); + if (v > (acc_t)nstl::numeric_limits<data_t>::max()) + v = (acc_t)nstl::numeric_limits<data_t>::max(); + return (typename utils::remove_reference<data_t>::type)v; +} + +template <typename data_t> +double saturate(const double &x) { + double v = x; + if (v < (double)nstl::numeric_limits<data_t>::lowest()) + v = (double)nstl::numeric_limits<data_t>::lowest(); + if (v > (double)nstl::numeric_limits<data_t>::max()) + v = (double)nstl::numeric_limits<data_t>::max(); + return v; +} + +template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) { + return x <= 127u ? x : 127; +} + +template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) { + return x >= 0 ? x : 0; +} + +template <typename out_t> +typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type +out_round(float v) { return (out_t)mxcsr_round(v); } + +template <typename out_t> +typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type +out_round(double v) { return (out_t)mxcsr_round((float)v); } + +template <typename out_t> +typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type +out_round(float v) { return v; } + +inline int gcd(int a, int b) { + a = impl::nstl::abs(a); + b = impl::nstl::abs(b); + if (a < b) { int x = a; a = b; b = x; } + + if (b == 0) return a; + + int r; + while ((r = a % b) != 0) { a = b; b = r; } + + return b; +} + +template <typename T> +inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; } + +/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */ +inline int ilog2q(size_t v) { + if (v == 0) + return -1; + + int p = 0; +# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0) + CP(32); CP(16); CP(8); CP(4); CP(2); CP(1); +# undef CP + return p; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U one_m_square(T x) { + return (U)(1 - x) * (1 + x); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U x_m_square(T x) { + return (U)(1 - x) * x; +} + +/* activation */ +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U relu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(s * alpha); +} +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U relu_bwd(T dd, T s, A alpha) { + return s > 0 ? dd : (U)(dd * alpha); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U tanh_fwd(T s) { + const float e = tanhf((float) s); + return (U)e; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U tanh_bwd(T dd, T s) { + const float e = tanh_fwd<float>((float) s); + return (U)(dd * (1 - e) * (1 + e)); +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U elu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); +} +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> + inline U elu_bwd(T dd, T s, A alpha) { + return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U square_fwd(T s) { + return s * s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U square_bwd(T dd, T s) { + return dd * 2 * s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U abs_fwd(T s) { + return s > 0 ? s : -s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U abs_bwd(T dd, T s) { + return s > 0 ? dd : s < 0 ? -dd : 0; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U sqrt_fwd(T s) { + return s > 0 ? (U)(::sqrtf((float)(s))) : 0; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U sqrt_bwd(T dd, T s) { + return s > 0 + ? (U)(dd / (2 * ::sqrtf((float)(s)))) + : 0; +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U linear_fwd(T s, A alpha, A beta) { + return (U)(alpha * s + beta); +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U linear_bwd(T dd, T s, A alpha, A beta) { + (void) s; + (void) beta; + return (U)(dd * alpha); +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U bounded_relu_fwd(T s, A alpha) { + s = s > 0 ? s : 0; + return s > alpha ? (U)(alpha) : s; +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U bounded_relu_bwd(T dd, T s, A alpha) { + return dd * (0 < s && s < alpha ? 1 : 0); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U soft_relu_fwd(T s) { + float max_logf = 8.872284e+01; //::logf(FLT_MAX) + return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U soft_relu_bwd(T dd, T s) { + return (U)(dd / (1 + ::expf((float)(-s)))); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U logistic_fwd(T s) { + U v = (U)(::expf((float) -s)); + return 1 / (1 + v); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U logistic_bwd(T dd, T s) { + U v = logistic_fwd<T, U>(s); + return dd * v * (1 - v); +} + +inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { + using namespace alg_kind; + using namespace utils; + const bool preserves_zero = true + && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); + return preserves_zero; +} + +inline float get_bias(const char *bias, size_t offset, data_type_t data_type) +{ + if (!bias) + return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits<dt>::type *)bias)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp new file mode 100644 index 0000000000..cea849c96e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::data_type; + +namespace { +bool memory_desc_sanity_check(int ndims,const dims_t dims, + data_type_t data_type, format_kind_t format_kind) { + if (ndims == 0) return true; + + bool ok = true + && dims != nullptr + && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS + && one_of(data_type, f32, s32, s8, u8) + && format_kind != format_kind::undef; + if (!ok) return false; + for (int d = 0; d < ndims; ++d) + if (dims[d] < 0) return false; + + return true; +} + +bool memory_desc_sanity_check(const memory_desc_t *md) { + if (md == nullptr) return false; + return memory_desc_sanity_check(md->ndims, md->dims, md->data_type, + format_kind::any); +} +} + +status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims, + const dims_t dims, data_type_t data_type, format_tag_t tag) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0 || tag == format_tag::undef) { + *memory_desc = types::zero_md(); + return success; + } + + format_kind_t format_kind = types::format_tag_to_kind(tag); + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind; + + status_t status = success; + if (tag == format_tag::undef) { + status = invalid_arguments; + } else if (tag == format_tag::any) { + // nop + } else if (format_kind == format_kind::blocked) { + status = memory_desc_wrapper::compute_blocking(md, tag); + } else { + assert(!"unreachable"); + status = invalid_arguments; + } + + if (status == success) + *memory_desc = md; + + return status; +} + +status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc, + int ndims, const dims_t dims, data_type_t data_type, + const dims_t strides) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0) { + *memory_desc = types::zero_md(); + return success; + } + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind::blocked; + + dims_t default_strides = {0}; + if (strides == nullptr) { + default_strides[md.ndims - 1] = 1; + for (int d = md.ndims - 2; d >= 0; --d) + default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1]; + strides = default_strides; + } else { + /* TODO: add sanity check for the provided strides */ + } + + array_copy(md.format_desc.blocking.strides, strides, md.ndims); + + *memory_desc = md; + + return status::success; +} + +status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md, + const memory_desc_t *parent_md, const dims_t dims, + const dims_t offsets) { + if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md)) + return invalid_arguments; + + const memory_desc_wrapper src_d(parent_md); + + for (int d = 0; d < src_d.ndims(); ++d) { + if (dims[d] < 0 || offsets[d] < 0 + || (offsets[d] + dims[d] > src_d.dims()[d])) + return invalid_arguments; + } + + if (src_d.format_kind() != format_kind::blocked) + return unimplemented; + + dims_t blocks; + src_d.compute_blocks(blocks); + + memory_desc_t dst_d = *parent_md; + auto &dst_d_blk = dst_d.format_desc.blocking; + + /* TODO: put this into memory_desc_wrapper */ + for (int d = 0; d < src_d.ndims(); ++d) { + /* very limited functionality for now */ + const bool ok = true + && offsets[d] % blocks[d] == 0 /* [r1] */ + && src_d.padded_offsets()[d] == 0 + && (false + || dims[d] % blocks[d] == 0 + || dims[d] < blocks[d]); + if (!ok) + return unimplemented; + + const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d]; + + dst_d.dims[d] = dims[d]; + dst_d.padded_dims[d] = is_right_border + ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d]; + dst_d.padded_offsets[d] = src_d.padded_offsets()[d]; + dst_d.offset0 += /* [r1] */ + offsets[d] / blocks[d] * dst_d_blk.strides[d]; + } + + *md = dst_d; + + return success; +} + +int mkldnn_memory_desc_equal(const memory_desc_t *lhs, + const memory_desc_t *rhs) { + if (lhs == rhs) return 1; + if (any_null(lhs, rhs)) return 0; + return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); +} + +size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) { + if (md == nullptr) return 0; + return memory_desc_wrapper(*md).size(); +} + +status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md, + engine_t *engine, void *handle) { + if (any_null(memory, engine)) return invalid_arguments; + memory_desc_t z_md = types::zero_md(); + return engine->memory_create(memory, md ? md : &z_md, handle); +} + +status_t mkldnn_memory_get_memory_desc(const memory_t *memory, + const memory_desc_t **md) { + if (any_null(memory, md)) return invalid_arguments; + *md = memory->md(); + return success; +} + +status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) { + if (any_null(memory, engine)) return invalid_arguments; + *engine = memory->engine(); + return success; +} + +status_t mkldnn_memory_get_data_handle(const memory_t *memory, + void **handle) { + if (any_null(handle)) + return invalid_arguments; + if (memory == nullptr) { + *handle = nullptr; + return success; + } + return memory->get_data_handle(handle); +} + +status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) { + if (any_null(memory)) return invalid_arguments; + return memory->set_data_handle(handle); +} + +status_t mkldnn_memory_destroy(memory_t *memory) { + delete memory; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp new file mode 100644 index 0000000000..03dfee01ff --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp @@ -0,0 +1,63 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_HPP +#define MEMORY_HPP + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" + +struct mkldnn_memory: public mkldnn::impl::c_compatible { + mkldnn_memory(mkldnn::impl::engine_t *engine, + const mkldnn::impl::memory_desc_t *md) + : engine_(engine), md_(*md) {} + virtual ~mkldnn_memory() {} + + /** allocates/initializes memory */ + virtual mkldnn::impl::status_t init() = 0; + + /** returns memory's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + /** returns memory's description */ + const mkldnn::impl::memory_desc_t *md() const { return &md_; } + + /** returns data handle */ + virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0; + + /** sets data handle */ + virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0; + + /** zeros padding */ + virtual mkldnn::impl::status_t zero_pad() const + { return mkldnn::impl::status::success; } + +protected: + mkldnn::impl::engine_t *engine_; + const mkldnn::impl::memory_desc_t md_; + +private: + mkldnn_memory() = delete; + mkldnn_memory(const mkldnn_memory &) = delete; + mkldnn_memory(mkldnn_memory &&) = delete; + mkldnn_memory &operator=(const mkldnn_memory &) = delete; + mkldnn_memory &operator=(mkldnn_memory &&) = delete; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp new file mode 100644 index 0000000000..8a99be33f3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp @@ -0,0 +1,212 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include <initializer_list> + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t fill_blocked(memory_desc_t &md, + std::initializer_list<int> perm, + std::initializer_list<int> inner_blks, + std::initializer_list<int> inner_idxs) { + const bool ok = true + && perm.size() == (size_t)md.ndims + && inner_blks.size() == inner_idxs.size(); + if (!ok) return status::invalid_arguments; + + md.offset0 = 0; + + blocking_desc_t &blk = md.format_desc.blocking; + + dim_t block_size = 1; + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + + blk.inner_nblks = (int)inner_blks.size(); + + int iblk = 0; + for (const auto &b: inner_idxs) + blk.inner_idxs[iblk++] = b; + + iblk = 0; + for (const auto &b: inner_blks) { + int dim = blk.inner_idxs[iblk]; + block_size *= b; + blocks[dim] *= b; + blk.inner_blks[iblk++] = b; + } + + utils::array_set(md.padded_offsets, 0, md.ndims); + for (int d = 0; d < md.ndims; ++d) + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + + dim_t stride = block_size; + // if only we use C++14, the initializer_list would have rbegin()/rend()... + for (int d = 0; d < md.ndims; ++d) + stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d]; + + for (const auto &d: perm) { + if (md.padded_dims[d] == 0) { + blk.strides[d] = 1; + continue; + } + stride /= md.padded_dims[d] / blocks[d]; + blk.strides[d] = stride; + } + + assert(stride == block_size); + + return status::success; +} + +status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag) +{ + using namespace format_tag; + + if (memory_desc.ndims == 0) return status::invalid_arguments; + +# define C(tag, ... /* perm, inner_blks, inner_idxs */) \ + case tag: return fill_blocked(memory_desc, __VA_ARGS__) + + switch (tag) { + C(a, {0}, {}, {}); + C(ab, {0, 1}, {}, {}); + C(abc, {0, 1, 2}, {}, {}); + C(abcd, {0, 1, 2, 3}, {}, {}); + C(abcde, {0, 1, 2, 3, 4}, {}, {}); + C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {}); + C(abdec, {0, 1, 3, 4, 2}, {}, {}); + C(acb, {0, 2, 1}, {}, {}); + C(acbde, {0, 2, 1, 3, 4}, {}, {}); + C(acdb, {0, 2, 3, 1}, {}, {}); + C(acdeb, {0, 2, 3, 4, 1}, {}, {}); + C(ba, {1, 0}, {}, {}); + C(bac, {1, 0, 2}, {}, {}); + C(bacd, {1, 0, 2, 3}, {}, {}); + C(bcda, {1, 2, 3, 0}, {}, {}); + C(cba, {2, 1, 0}, {}, {}); + C(cdba, {2, 3, 1, 0}, {}, {}); + C(cdeba, {2, 3, 4, 1, 0}, {}, {}); + C(decab, {3, 4, 2, 0, 1}, {}, {}); + + C(Abc4a, {0, 1, 2}, {4}, {0}); + C(aBc4b, {0, 1, 2}, {4}, {1}); + C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1}); + C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0}); + C(Abcd4a, {0, 1, 2, 3}, {4}, {0}); + C(aBcd4b, {0, 1, 2, 3}, {4}, {1}); + C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0}); + C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2}); + C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0}); + C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1}); + C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0}); + C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1}); + C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1}); + C(aBdc4b, {0, 1, 3, 2}, {4}, {1}); + C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1}); + C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1}); + C(Acb4a, {0, 2, 1}, {4}, {0}); + C(Acdb4a, {0, 2, 3, 1}, {4}, {0}); + C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0}); + + C(Abc16a, {0, 1, 2}, {16}, {0}); + C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1}); + C(aBc16b, {0, 1, 2}, {16}, {1}); + C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0}); + C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0}); + C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1}); + C(aBc8b, {0, 1, 2}, {8}, {1}); + C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1}); + C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0}); + C(Abcd16a, {0, 1, 2, 3}, {16}, {0}); + C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1}); + C(aBcd16b, {0, 1, 2, 3}, {16}, {1}); + C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0}); + C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2}); + C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1}); + C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1}); + C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); + C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); + C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); + C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1}); + C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1}); + C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0}); + C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2}); + C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2}); + C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1}); + C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0}); + C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1}); + C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1}); + C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0}); + C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2}); + C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1}); + C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2}); + C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2}); + C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2}); + C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0}); + C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1}); + C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1}); + C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1}); + C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1}); + C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0}); + C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2}); + C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2}); + C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1}); + C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1}); + C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2}); + C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1}); + C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2}); + C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2}); + C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1}); + C(aBdc16b, {0, 1, 3, 2}, {16}, {1}); + C(aBdc8b, {0, 1, 3, 2}, {8}, {1}); + C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1}); + C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1}); + C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1}); + C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1}); + C(Acb16a, {0, 2, 1}, {16}, {0}); + C(Acb8a, {0, 2, 1}, {8}, {0}); + C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); + C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); + C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); + C(Acdb8a, {0, 2, 3, 1}, {8}, {0}); + C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0}); + C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0}); + C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); + C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); + default: break; + } + +#undef C + + return status::invalid_arguments; +} + +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp new file mode 100644 index 0000000000..1758f9078a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_DESC_WRAPPER_HPP +#define MEMORY_DESC_WRAPPER_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +/** thin wrapper class over \struct memory_desc_t which allows easy + * manipulations with underlying C structure, which is taken by reference */ +struct memory_desc_wrapper: public c_compatible { + const memory_desc_t *md_; + + /** constructor which takes a reference to a constant underlying C memory + * descriptor \param md */ + memory_desc_wrapper(const memory_desc_t *md): md_(md) {} + memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {} + + /* implementing attributes */ + int ndims() const { return md_->ndims; } + const dims_t &dims() const { return md_->dims; } + data_type_t data_type() const { return md_->data_type; } + + const dims_t &padded_dims() const { return md_->padded_dims; } + const dims_t &padded_offsets() const { return md_->padded_offsets; } + dim_t offset0() const { return md_->offset0; } + + format_kind_t format_kind() const { return md_->format_kind; } + + bool is_blocking_desc() const + { return format_kind() == format_kind::blocked; } + bool is_wino_desc() const + { return format_kind() == format_kind::wino; } + bool is_rnn_packed_desc() const + { return format_kind() == format_kind::rnn_packed; } + + const blocking_desc_t &blocking_desc() const { + assert(is_blocking_desc()); + return md_->format_desc.blocking; + } + const wino_desc_t &wino_desc() const { + assert(is_wino_desc()); + return md_->format_desc.wino_desc; + } + const rnn_packed_desc_t &rnn_packed_desc() const { + assert(is_rnn_packed_desc()); + return md_->format_desc.rnn_packed_desc; + } + + const memory_extra_desc_t &extra() const { return md_->extra; } + + /* some useful function */ + + /** returns the number of elements including padding if \param with_padding + * is true, and the number of data elements otherwise */ + dim_t nelems(bool with_padding = false) const { + if (is_zero()) return 0; + return utils::array_product( + with_padding ? padded_dims() : dims(), ndims()); + } + + /** returns true if memory descriptor is zero */ + bool is_zero() const { return ndims() == 0; } + + /** returns true if memory descriptor contains zero as one of its dim */ + bool has_zero_dim() const { return nelems() == 0; } + + /** return the size of data type (a shortcut) */ + size_t data_type_size() const + { return types::data_type_size(data_type()); } + + /** return the size of data type of additional buffer */ + size_t additional_buffer_data_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) + return sizeof(int32_t); + return 0; + } + + /** return true if memory format has additional buffer */ + bool is_additional_buffer() const { + return (extra().flags & memory_extra_flags::compensation_conv_s8s8); + } + + /** returns the size of additional buffer */ + size_t additional_buffer_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) { + int cmask = extra().compensation_mask; + assert(cmask == 1 || cmask == 3); + dim_t prod = 1; + for (int d = 0; d < ndims(); ++d) + if (cmask & (1<<d)) prod *= padded_dims()[d]; + return prod * additional_buffer_data_size(); + } + + return 0; + } + + /** returns the size required to store described memory + * note: if offset0 != 0 returns 0 (need to specify the behavior) */ + size_t size() const { + if (is_zero() || has_zero_dim() || format_kind() == format_kind::any) + return 0; + + if (format_kind() == format_kind::wino) { + return wino_desc().size; + } else if (format_kind() == format_kind::rnn_packed) { + return rnn_packed_desc().size; + } else { + if (offset0() != 0) return 0; + + dims_t blocks = {0}; + compute_blocks(blocks); + + const auto &bd = blocking_desc(); + + size_t max_size = 0; + for (int d = 0; d < ndims(); ++d) + max_size = nstl::max<size_t>(max_size, + padded_dims()[d] / blocks[d] * bd.strides[d]); + + if (max_size == 1 && bd.inner_nblks != 0) { + max_size = utils::array_product(bd.inner_blks, bd.inner_nblks); + } + + return max_size * data_type_size() + additional_buffer_size(); + } + } + + /** returns true if data is dense in memory */ + bool is_dense(bool with_padding = false) const { + if (utils::one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + return nelems(with_padding) * data_type_size() == size(); + } + + /** returns true if memory desc is fully defined */ + bool is_defined() const { return format_kind() != format_kind::any; } + + /** returns true if the only (potentially) padded dim is \param dim */ + bool only_padded_dim(int dim) const { + for (int d = 0; d < ndims(); ++d) + if (d != dim && dims()[d] != padded_dims()[d]) + return false; + return true; + } + + /** returns true if memory desc has blocked layout and block dims are 1s */ + bool is_plain() const { + if (!is_blocking_desc()) return false; + return blocking_desc().inner_nblks == 0; + } + + /** returns overall block sizes */ + void compute_blocks(dims_t blocks) const { + if (!is_blocking_desc()) { + utils::array_set(blocks, 0, ndims()); + return; + } + + utils::array_set(blocks, 1, ndims()); + + const auto &bd = blocking_desc(); + for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) + blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk]; + } + + /* comparison section */ + + bool operator==(const memory_desc_wrapper &rhs) const + { return *this->md_ == *rhs.md_; } + bool operator!=(const memory_desc_wrapper &rhs) const + { return !operator==(rhs); } + bool operator==(const memory_desc_t &rhs) const + { return operator==(memory_desc_wrapper(rhs)); } + bool operator!=(const memory_desc_t &rhs) const + { return !operator==(rhs); } + + /** returns true if data (w/o padding if with_padding == false and w/ + * padding otherwise) have the same physical structure, i.e. dimensions, + * strides, and blocked structure. Depending on with_data_type flag + * data_type is taken or not taken into account. dim_start allows to check + * similarity for the logical part of data [dim_start .. ndims()]. + * CAUTION: format kind any and undef are not similar to whatever, hence the + * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */ + /* TODO: revise */ + bool similar_to(const memory_desc_wrapper &rhs, + bool with_padding = true, bool with_data_type = true, + int dim_start = 0) const; + + /** returns true if one memory can be reordered to another */ + bool consistent_with(const memory_desc_wrapper &rhs) const; + + /** returns true if the memory desc corresponds to the given format tag and + * strides. + * @sa memory_desc_matches_tag */ + bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const { + return memory_desc_matches_tag(*md_, tag, strides); + } + + /** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ + template <typename... Tags> + format_tag_t matches_one_of_tag(Tags ...tags) const { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(*md_, tag)) + return tag; + } + return format_tag::undef; + } + + /* offset section */ + + /** returns physical offset by logical one. logical offset is represented by + * an array \param pos. if \param is_pos_padded is true \param pos + * represents the position in already padded area */ + dim_t off_v(const dims_t pos, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + const blocking_desc_t &blk = blocking_desc(); + + dims_t pos_copy = {0}; + for (int d = 0; d < ndims(); ++d) + pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]); + + dim_t phys_offset = offset0(); + + if (blk.inner_nblks > 0) { + dim_t blk_stride = 1; + for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) { + const int d = blk.inner_idxs[iblk]; + const dim_t p = pos_copy[d] % blk.inner_blks[iblk]; + + phys_offset += p * blk_stride; + + pos_copy[d] /= blk.inner_blks[iblk]; + + blk_stride *= blk.inner_blks[iblk]; + } + } + + for (int d = 0; d < ndims(); ++d) { + const dim_t p = pos_copy[d]; + phys_offset += p * blk.strides[d]; + } + + return phys_offset; + } + + /** returns physical offset by logical one. logical offset is represented by + * a scalar \param l_offset. if \param is_pos_padded is true, \param + * l_offset represents logical offset in already padded area */ + dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + dims_t pos; + for (int rd = 0; rd < ndims(); ++rd) { + const int d = ndims() - 1 - rd; + const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d]; + pos[d] = l_offset % cur_dim; + l_offset /= cur_dim; + } + return off_v(pos, is_pos_padded); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) */ + template<typename... Args> + dim_t off(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, false); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) in already + * padded area */ + template<typename... Args> + dim_t off_padding(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, true); + } + + /** returns physical offset by logical one. Logical offset is represented by + * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a + * user responsibility to adjust the result to get offset within blocks */ + template<typename ...Args> + dim_t blk_off(Args... args) const { + return _blk_off<sizeof...(args), Args...>(args...); + } + + template<bool skip_first, typename T, typename ...Args> + dim_t blk_off(T xn, Args... args) const { + return skip_first + ? blk_off<Args...>(args...) + : blk_off<T, Args...>(xn, args...); + } + + /* static functions section */ + /* TODO: replace with non-static, once md_ becomes non-const ref */ + + static status_t compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag); + +private: + /* TODO: put logical_offset in utils */ + template<typename T> + dim_t logical_offset(T x0) const { return x0; } + + template<typename T, typename... Args> + dim_t logical_offset(T xn, Args... args) const { + const size_t n_args = sizeof...(args); + return xn * utils::array_product<n_args>( + &dims()[ndims() - n_args]) + logical_offset(args...); + } + + template<int ORIG_LEN, typename ...Void> + dim_t _blk_off() const { return offset0(); } + + template<int ORIG_LEN, typename T, typename ...Args> + dim_t _blk_off(T xc, Args ...args) const { + assert(is_blocking_desc()); + constexpr int dc = ORIG_LEN - sizeof...(args) - 1; + return xc * blocking_desc().strides[dc] + + _blk_off<ORIG_LEN, Args...>(args...); + } +}; + +inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, + bool with_padding, bool with_data_type, int dim_start) const { + using namespace utils; + + if (one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + if (is_wino_desc() || is_rnn_packed_desc()) + return false; + + const int ds = dim_start; + const auto &blk = blocking_desc(); + const auto &r_blk = rhs.blocking_desc(); + + return ndims() == rhs.ndims() + && dim_start <= ndims() /* guard */ + && format_kind() == rhs.format_kind() + && IMPLICATION(with_data_type, data_type() == rhs.data_type()) + && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds) + && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds) + && blk.inner_nblks == r_blk.inner_nblks + && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks) + && IMPLICATION(with_padding, true + && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds, + ndims() - ds) + && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds, + ndims() - ds)); +} + +inline bool memory_desc_wrapper::consistent_with( + const memory_desc_wrapper &rhs) const { + if (ndims() == rhs.ndims()) { + for (int d = 0; d < ndims(); ++d) { + if (dims()[d] != rhs.dims()[d]) return false; + } + return true; + } else { + /* TODO: revise. + * is the following possible? + * [1, a, b] <--reorder--> [a, b] + * [a, 1, b] <--reorder--> [a, b] + * not, at least for now */ + return false; + } +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp new file mode 100644 index 0000000000..ec077b308c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp @@ -0,0 +1,295 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_TRACKING_HPP +#define MEMORY_TRACKING_HPP + +#include <assert.h> +#include <unordered_map> + +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace memory_tracking { + +/* Memory tracking capabilities + * + * The main purpose of this header file is to provide uniform way to register + * required memory for a scratchpad at a primitive descriptor creation time + * and then easily access it having only the base address of the scratchpad. + * + * Primitives might contain multiple disjoint parts that require temporary + * buffers (known as scratchpad) during their execution. A primitive descriptor + * should summarize all the needs into one single number -- the buffer size + * that would be requested from a user. At execution time, the corresponding + * primitive will receive a base pointer to a scratchpad. It then needs to + * provide each part of algorithm the corresponding piece of memory. Three main + * challenges here are: + * 1. Track correct offset (from the base scratchpad address) for each piece + * 2. Algorithm might require that different memory pieces to be aligned, so + * the scratchpad size is no more just a sum of size of the corresponding + * subparts. + * 3. While a primitive is responsible for its scratchpad, the implementation + * might use some other basic blocks (e.g. cpu_reducer) that also require + * scratchpad memory. So there should be a simple way of passing the + * information back and force between the main algorithm (a primitive) and + * auxiliary stuff that lives completely separately from it (e.g. reducer). + * + * To address these challenges this header file provides 3 structures: + * 1. registry_t -- the class the stores the information about requested + * memory. The information includes required size and desired + * alignment for each piece. This class is also responsible + * for computing the right offset to a given piece using the + * base pointer. + * This class is basically a ledger with all entries. + * Lives in primitive descriptors. + * + * 2. registrar_t -- the interface to a registry_t to book memory. Used at + * primitive descriptor creation time only. Contains a + * reference to the corresponding *mutable* registry. + * Always modifiable. + * Allows chaining (using prefixes). + * + * 3. grantor_t -- the interface to a registry_t to access memory. Used at + * primitive execution time only. Contains a reference to + * the corresponding *constant* registry and base pointer. + * Always constant. + * Allows chaining (using prefixes). + * + * Both registrar_t and grantor_t allow chaining with extra prefix provided. + * The feature is useful when a primitive offload a part of computations to + * some other primitives which require their own scratchpad space + * (e.g. reducer). Prefixes are used to avoid key collision in cases when + * multiple sub-primitive (e.g. multiple reducers) are used. + * + * A short example below demonstrates how to use aforementioned classes. In it + * the main primitive is convolution that uses scratchpad for keeping padded + * bias. It also needs a reducer, that needs its own space as well. + * + * ``` c++ + * struct reducer_t { + * static void init(registrar_t &scratchpad) { + * // preserve space for the reduction (one page aligned) + * scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096); + * } + * + * void exec(const grantor_t &scratchpad) { + * // get the pointer to preserved space. scratchpad came from + * // upper primitive (convolution in this example) + * auto space = scratchpad.get<float>(key_reducer_space); + * + * space[:] += ...; + * } + * }; + * + * struct conv_t { + * struct pd_t { + * void init() { + * registrar_t scratchpad(scratchpad_registry_); + * + * // preserve a space for padded bias (using default alignment) + * scratchpad.book(key_conv_padded_bias, 128); + * + * // create a proxy registrar for the reducer All entries made + * // by reducer would live in convolution's registry, but would + * // have their own `prefix`, so no interference with conv's + * // buffers. + * registrar_t reducer_scratchpad(scratchpad, prefix_reducer); + * + * reducer_t::init(reducer_scratchpad); + * } + * + * registry_t scratchpad_registry_; + * } + * + * void exec() { + * // get the base pointer to a scratchpad memory from a user + * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD); + * + * // create a grantor to the scratchpad (and provide the base + * // pointer). + * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr); + * + * // access the padded_bias (need only key name and the grantor) + * auto padded_bias = scratchpad.get<float>(key_conv_padded_bias); + * + * // to give the `right` grantor to reducer we need to add the + * // corresponding prefix, so that reducer would be able to access + * // its keys. The call is very similar to the one in pd_t::init + * // with only difference in types: grantor_t vs registrar_t. + * grantor_t reducer_scratchpad(scratchpad, prefix_reducer); + * reducer->exec(reducer_scratchpad); + * } + * }; + * ``` + */ + + +/* namespace with common keys and prefixes */ +namespace names { +enum { + key_none = 0, + key_bnorm_tmp_mean, + key_bnorm_tmp_var, + key_bnorm_tmp_diff_ss, + key_bnorm_tmp_stats, + key_bnorm_reduction, + key_concat_iptrs, + key_concat_istrides, + key_concat_nelems, + key_concat_optrs, + key_conv_adjusted_scales, + key_conv_bia_reduction, + key_conv_gemm_col, + key_conv_gemm_imtr, + key_conv_int_dat_in_acc_dt, + key_conv_padded_bias, + key_conv_rtus_space, + key_conv_tr_diff_dst, + key_conv_tr_diff_dst_bctx, + key_conv_tr_src, + key_conv_tr_src_bctx, + key_conv_wei_reduction, + key_conv_wei_bia_reduction, + key_conv_wei_bia_reduction_bctx, + key_iprod_int_dat_in_acc_dt, + key_reducer_space, + key_reducer_space_bctx, + key_reorder_wino_plain, + key_reorder_wino_transform_space, + key_reorder_rnn_weights_quantization, + key_reorder_rnn_weights_reduction, + key_rnn_space, + key_rnn_ptrs_bia, + key_rnn_ptrs_wei_layer, + key_rnn_ptrs_wei_iter, + key_softmax_reduction, + key_wino_U, + key_wino_V, + key_wino_M, + key_barrier, +}; + +enum { + prefix_none = 0, + prefix_reducer_bia, + prefix_reducer_wei, +}; +} + +// level 0: 00 00 00 xxx +// level 1: 00 00 aa xxx +// level 2: 00 aa bb xxx +// level 3: aa bb cc xxx +// max # of levels: 3 + 1 (base_level) +// here: +// xxx : [1 .. MAX_KEY) : key +// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3 + +using key_t = uint32_t; +enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), }; + +/// generates global key based on a prefix and a local key +inline key_t make_key(key_t prefix, key_t key) { return prefix + key; } + +/// generates global prefix based on the global parent and the local ones +inline key_t make_prefix(key_t parent_prefix, key_t prefix) +{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; } + +struct registrar_t; +struct grantor_t; + +struct registry_t { + void book(const key_t &key, size_t size, size_t alignment) { + if (size == 0) return; + assert(offset_map_.count(key) == 0); + + size = utils::rnd_up(size, minimal_alignment); + alignment = nstl::max<size_t>(alignment, minimal_alignment); + offset_map_[key] = entry_t{size_, size, alignment}; + + size_ += size + alignment - minimal_alignment; + } + + void *get(const key_t &key, void *base_ptr) const { + if (base_ptr == nullptr) { assert(size() == 0); return nullptr; } + if (offset_map_.count(key) != 1) return nullptr; + + const auto &e = offset_map_.at(key); + base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment); + char *ptr = (char *)base_ptr + e.offset; + return utils::align_ptr<void>(ptr, e.alignment); + } + + size_t size() const + { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; } + + registrar_t registrar(); + grantor_t grantor(void *base_ptr) const; + +protected: + enum { minimal_alignment = 64 }; + struct entry_t { size_t offset, size, alignment; }; + + std::unordered_map<key_t, entry_t> offset_map_; + size_t size_ = 0; +}; + +struct registrar_t { + enum { default_alignment = 64 }; + + registrar_t(registry_t ®istry): registry_(registry), prefix_(0) {} + registrar_t(registrar_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) {} + + void book(const key_t &key, size_t size, + size_t alignment = default_alignment) + { registry_.book(make_key(prefix_, key), size, alignment); } + +protected: + registry_t ®istry_; + const key_t prefix_; +}; + +struct grantor_t { + grantor_t(const registry_t ®istry, void *base_ptr) + : registry_(registry), prefix_(0), base_ptr_(base_ptr) {} + grantor_t(const grantor_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) + , base_ptr_(parent.base_ptr_) {} + + template <typename T = void> T *get(const key_t &key) const + { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); } + +protected: + const registry_t ®istry_; + const key_t prefix_; + void *base_ptr_; +}; + +inline registrar_t registry_t::registrar() { return registrar_t(*this); } +inline grantor_t registry_t::grantor(void *base_ptr) const +{ return grantor_t(*this, base_ptr); } + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp new file mode 100644 index 0000000000..2ef4a8fddc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <stdio.h> +#include <cinttypes> + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#define DPRINT(...) do { \ + int l = snprintf(str + written_len, str_len, __VA_ARGS__); \ + if (l < 0) return l; \ + if ((size_t)l >= str_len) return -1; \ + written_len += l; str_len -= l; \ +} while(0) + +int mkldnn_md2fmt_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1u) + return -1; + + int written_len = 0; + + if (mdesc == nullptr) { + DPRINT("%s::%s::", + mkldnn_dt2str(data_type::undef), + mkldnn_fmt_kind2str(format_kind::undef)); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + DPRINT("%s:", mkldnn_dt2str(md.data_type())); + + bool padded_dims = false, padded_offsets = false; + for (int d = 0; d < md.ndims(); ++d) { + if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true; + if (md.padded_offsets()[d] != 0) padded_offsets = true; + } + bool offset0 = md.offset0(); + DPRINT("%s%s%s:", + padded_dims ? "p" : "", + padded_offsets ? "o" : "", + offset0 ? "0" : ""); + + DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind())); + + if (!md.is_blocking_desc()) { + /* TODO: extend */ + DPRINT("%s:", ""); + } else { + const auto &blk = md.blocking_desc(); + + dims_t blocks; + md.compute_blocks(blocks); + + char dim_chars[MKLDNN_MAX_NDIMS + 1]; + + bool plain = true; + for (int d = 0; d < md.ndims(); ++d) { + dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d; + if (blocks[d] != 1) plain = false; + } + + dims_t strides; + utils::array_copy(strides, blk.strides, md.ndims()); + utils::simultaneous_sort(strides, dim_chars, md.ndims(), + [](dim_t a, dim_t b) { return b - a; }); + + dim_chars[md.ndims()] = '\0'; + DPRINT("%s", dim_chars); + + if (!plain) { + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + DPRINT("%d%c", (int)blk.inner_blks[iblk], + 'a' + (char)blk.inner_idxs[iblk]); + } + } + + DPRINT("%s", ":"); + } + + DPRINT("f%lx", (long)md.extra().flags); + + return written_len; +} + +int mkldnn_md2dim_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1) + return -1; + + int written_len = 0; + + if (mdesc == nullptr || mdesc->ndims == 0) { + DPRINT("%s", ""); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + for (int d = 0; d < md.ndims() - 1; ++d) + DPRINT("%" PRId64 "x", md.dims()[d]); + DPRINT("%" PRId64, md.dims()[md.ndims() - 1]); + + return written_len; +} + +#undef DPRINT diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp new file mode 100644 index 0000000000..16a8f7ea5e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp @@ -0,0 +1,365 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#include <assert.h> + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +const char *mkldnn_status2str(mkldnn_status_t v) { + if (v == mkldnn_success) return "success"; + if (v == mkldnn_out_of_memory) return "out_of_memory"; + if (v == mkldnn_try_again) return "try_again"; + if (v == mkldnn_invalid_arguments) return "invalid_arguments"; + if (v == mkldnn_not_ready) return "not_ready"; + if (v == mkldnn_unimplemented) return "unimplemented"; + if (v == mkldnn_iterator_ends) return "iterator_ends"; + if (v == mkldnn_runtime_error) return "runtime_error"; + if (v == mkldnn_not_required) return "not_required"; + assert(!"unknown status"); + return "unknown status"; +} + +const char *mkldnn_dt2str(mkldnn_data_type_t v) { + if (v == mkldnn_data_type_undef) return "undef"; + if (v == mkldnn_f32) return "f32"; + if (v == mkldnn_s32) return "s32"; + if (v == mkldnn_s8) return "s8"; + if (v == mkldnn_u8) return "u8"; + assert(!"unknown dt"); + return "unknown dt"; +} + +const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) { + if (v == mkldnn_format_kind_undef) return "undef"; + if (v == mkldnn_format_kind_any) return "any"; + if (v == mkldnn_blocked) return "blocked"; + if (v == mkldnn_format_kind_wino) return "wino"; + if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed"; + assert(!"unknown fmt_kind"); + return "unknown fmt_kind"; +} + +const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) { + if (v == mkldnn_format_tag_undef) return "undef"; + if (v == mkldnn_format_tag_any) return "format_tag_any"; + if (v == mkldnn_a) return "a"; + if (v == mkldnn_ab) return "ab"; + if (v == mkldnn_abc) return "abc"; + if (v == mkldnn_abcd) return "abcd"; + if (v == mkldnn_abcde) return "abcde"; + if (v == mkldnn_abcdef) return "abcdef"; + if (v == mkldnn_abdec) return "abdec"; + if (v == mkldnn_acb) return "acb"; + if (v == mkldnn_acbde) return "acbde"; + if (v == mkldnn_acdb) return "acdb"; + if (v == mkldnn_acdeb) return "acdeb"; + if (v == mkldnn_ba) return "ba"; + if (v == mkldnn_bac) return "bac"; + if (v == mkldnn_bacd) return "bacd"; + if (v == mkldnn_bcda) return "bcda"; + if (v == mkldnn_cba) return "cba"; + if (v == mkldnn_cdba) return "cdba"; + if (v == mkldnn_cdeba) return "cdeba"; + if (v == mkldnn_decab) return "decab"; + if (v == mkldnn_Abc16a) return "Abc16a"; + if (v == mkldnn_ABc16a16b) return "ABc16a16b"; + if (v == mkldnn_aBc16b) return "aBc16b"; + if (v == mkldnn_ABc16b16a) return "ABc16b16a"; + if (v == mkldnn_Abc4a) return "Abc4a"; + if (v == mkldnn_aBc4b) return "aBc4b"; + if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b"; + if (v == mkldnn_ABc4b4a) return "ABc4b4a"; + if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a"; + if (v == mkldnn_ABc8a8b) return "ABc8a8b"; + if (v == mkldnn_aBc8b) return "aBc8b"; + if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b"; + if (v == mkldnn_ABc8b8a) return "ABc8b8a"; + if (v == mkldnn_Abcd16a) return "Abcd16a"; + if (v == mkldnn_ABcd16a16b) return "ABcd16a16b"; + if (v == mkldnn_aBcd16b) return "aBcd16b"; + if (v == mkldnn_ABcd16b16a) return "ABcd16b16a"; + if (v == mkldnn_aBCd16b16c) return "aBCd16b16c"; + if (v == mkldnn_aBCd16c16b) return "aBCd16c16b"; + if (v == mkldnn_Abcd4a) return "Abcd4a"; + if (v == mkldnn_aBcd4b) return "aBcd4b"; + if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b"; + if (v == mkldnn_ABcd4b4a) return "ABcd4b4a"; + if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c"; + if (v == mkldnn_aBCd4c4b) return "aBCd4c4b"; + if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a"; + if (v == mkldnn_ABcd8a8b) return "ABcd8a8b"; + if (v == mkldnn_aBcd8b) return "aBcd8b"; + if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b"; + if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b"; + if (v == mkldnn_ABcd8b8a) return "ABcd8b8a"; + if (v == mkldnn_aBCd8b8c) return "aBCd8b8c"; + if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c"; + if (v == mkldnn_aBCd8c8b) return "aBCd8c8b"; + if (v == mkldnn_Abcde16a) return "Abcde16a"; + if (v == mkldnn_ABcde16a16b) return "ABcde16a16b"; + if (v == mkldnn_aBcde16b) return "aBcde16b"; + if (v == mkldnn_ABcde16b16a) return "ABcde16b16a"; + if (v == mkldnn_aBCde16b16c) return "aBCde16b16c"; + if (v == mkldnn_aBCde16c16b) return "aBCde16c16b"; + if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c"; + if (v == mkldnn_Abcde4a) return "Abcde4a"; + if (v == mkldnn_aBcde4b) return "aBcde4b"; + if (v == mkldnn_ABcde4b4a) return "ABcde4b4a"; + if (v == mkldnn_aBCde4b4c) return "aBCde4b4c"; + if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c"; + if (v == mkldnn_aBCde4c4b) return "aBCde4c4b"; + if (v == mkldnn_Abcde8a) return "Abcde8a"; + if (v == mkldnn_ABcde8a8b) return "ABcde8a8b"; + if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b"; + if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b"; + if (v == mkldnn_ABcde8b8a) return "ABcde8b8a"; + if (v == mkldnn_aBCde8b8c) return "aBCde8b8c"; + if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c"; + if (v == mkldnn_aBCde8c8b) return "aBCde8c8b"; + if (v == mkldnn_aBcdef16b) return "aBcdef16b"; + if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c"; + if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b"; + if (v == mkldnn_aBcdef4b) return "aBcdef4b"; + if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b"; + if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c"; + if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c"; + if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b"; + if (v == mkldnn_aBdc16b) return "aBdc16b"; + if (v == mkldnn_aBdc4b) return "aBdc4b"; + if (v == mkldnn_aBdc8b) return "aBdc8b"; + if (v == mkldnn_aBdec16b) return "aBdec16b"; + if (v == mkldnn_aBdec4b) return "aBdec4b"; + if (v == mkldnn_aBdec8b) return "aBdec8b"; + if (v == mkldnn_aBdefc16b) return "aBdefc16b"; + if (v == mkldnn_aBdefc4b) return "aBdefc4b"; + if (v == mkldnn_aBdefc8b) return "aBdefc8b"; + if (v == mkldnn_Acb16a) return "Acb16a"; + if (v == mkldnn_Acb4a) return "Acb4a"; + if (v == mkldnn_Acb8a) return "Acb8a"; + if (v == mkldnn_aCBd16b16c) return "aCBd16b16c"; + if (v == mkldnn_aCBde16b16c) return "aCBde16b16c"; + if (v == mkldnn_Acdb16a) return "Acdb16a"; + if (v == mkldnn_Acdb4a) return "Acdb4a"; + if (v == mkldnn_Acdb8a) return "Acdb8a"; + if (v == mkldnn_Acdeb16a) return "Acdeb16a"; + if (v == mkldnn_Acdeb4a) return "Acdeb4a"; + if (v == mkldnn_Acdeb8a) return "Acdeb8a"; + if (v == mkldnn_BAc16a16b) return "BAc16a16b"; + if (v == mkldnn_BAcd16a16b) return "BAcd16a16b"; + if (v == mkldnn_format_tag_last) return "format_tag_last"; + if (v == mkldnn_x) return "x"; + if (v == mkldnn_nc) return "nc"; + if (v == mkldnn_cn) return "cn"; + if (v == mkldnn_ncw) return "ncw"; + if (v == mkldnn_nwc) return "nwc"; + if (v == mkldnn_nchw) return "nchw"; + if (v == mkldnn_nhwc) return "nhwc"; + if (v == mkldnn_chwn) return "chwn"; + if (v == mkldnn_ncdhw) return "ncdhw"; + if (v == mkldnn_ndhwc) return "ndhwc"; + if (v == mkldnn_oi) return "oi"; + if (v == mkldnn_io) return "io"; + if (v == mkldnn_oiw) return "oiw"; + if (v == mkldnn_wio) return "wio"; + if (v == mkldnn_oihw) return "oihw"; + if (v == mkldnn_hwio) return "hwio"; + if (v == mkldnn_ihwo) return "ihwo"; + if (v == mkldnn_iohw) return "iohw"; + if (v == mkldnn_oidhw) return "oidhw"; + if (v == mkldnn_dhwio) return "dhwio"; + if (v == mkldnn_goiw) return "goiw"; + if (v == mkldnn_goihw) return "goihw"; + if (v == mkldnn_hwigo) return "hwigo"; + if (v == mkldnn_giohw) return "giohw"; + if (v == mkldnn_goidhw) return "goidhw"; + if (v == mkldnn_tnc) return "tnc"; + if (v == mkldnn_ntc) return "ntc"; + if (v == mkldnn_ldsnc) return "ldsnc"; + if (v == mkldnn_ldigo) return "ldigo"; + if (v == mkldnn_ldgoi) return "ldgoi"; + if (v == mkldnn_ldgo) return "ldgo"; + if (v == mkldnn_nCdhw16c) return "nCdhw16c"; + if (v == mkldnn_nCdhw4c) return "nCdhw4c"; + if (v == mkldnn_nCdhw8c) return "nCdhw8c"; + if (v == mkldnn_nChw16c) return "nChw16c"; + if (v == mkldnn_nChw4c) return "nChw4c"; + if (v == mkldnn_nChw8c) return "nChw8c"; + if (v == mkldnn_nCw16c) return "nCw16c"; + if (v == mkldnn_nCw4c) return "nCw4c"; + if (v == mkldnn_nCw8c) return "nCw8c"; + if (v == mkldnn_IOw16o16i) return "IOw16o16i"; + if (v == mkldnn_OIw16i16o) return "OIw16i16o"; + if (v == mkldnn_OIw16o16i) return "OIw16o16i"; + if (v == mkldnn_Oiw16o) return "Oiw16o"; + if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i"; + if (v == mkldnn_OIw4i4o) return "OIw4i4o"; + if (v == mkldnn_Oiw4o) return "Oiw4o"; + if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i"; + if (v == mkldnn_OIw8i8o) return "OIw8i8o"; + if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o"; + if (v == mkldnn_OIw8o8i) return "OIw8o8i"; + if (v == mkldnn_Owi16o) return "Owi16o"; + if (v == mkldnn_Owi4o) return "Owi4o"; + if (v == mkldnn_Owi8o) return "Owi8o"; + if (v == mkldnn_IOhw16o16i) return "IOhw16o16i"; + if (v == mkldnn_Ohwi16o) return "Ohwi16o"; + if (v == mkldnn_Ohwi4o) return "Ohwi4o"; + if (v == mkldnn_Ohwi8o) return "Ohwi8o"; + if (v == mkldnn_OIhw16i16o) return "OIhw16i16o"; + if (v == mkldnn_OIhw16o16i) return "OIhw16o16i"; + if (v == mkldnn_Oihw16o) return "Oihw16o"; + if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i"; + if (v == mkldnn_OIhw4i4o) return "OIhw4i4o"; + if (v == mkldnn_Oihw4o) return "Oihw4o"; + if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i"; + if (v == mkldnn_OIhw8i8o) return "OIhw8i8o"; + if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o"; + if (v == mkldnn_OIhw8o8i) return "OIhw8o8i"; + if (v == mkldnn_Odhwi16o) return "Odhwi16o"; + if (v == mkldnn_Odhwi4o) return "Odhwi4o"; + if (v == mkldnn_Odhwi8o) return "Odhwi8o"; + if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o"; + if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i"; + if (v == mkldnn_Oidhw16o) return "Oidhw16o"; + if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o"; + if (v == mkldnn_Oidhw4o) return "Oidhw4o"; + if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i"; + if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o"; + if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i"; + if (v == mkldnn_Goiw16g) return "Goiw16g"; + if (v == mkldnn_gIOw16o16i) return "gIOw16o16i"; + if (v == mkldnn_gOIw16i16o) return "gOIw16i16o"; + if (v == mkldnn_gOIw16o16i) return "gOIw16o16i"; + if (v == mkldnn_gOiw16o) return "gOiw16o"; + if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i"; + if (v == mkldnn_gOIw4i4o) return "gOIw4i4o"; + if (v == mkldnn_gOiw4o) return "gOiw4o"; + if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i"; + if (v == mkldnn_gOIw8i8o) return "gOIw8i8o"; + if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o"; + if (v == mkldnn_gOIw8o8i) return "gOIw8o8i"; + if (v == mkldnn_gOwi16o) return "gOwi16o"; + if (v == mkldnn_gOwi4o) return "gOwi4o"; + if (v == mkldnn_gOwi8o) return "gOwi8o"; + if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i"; + if (v == mkldnn_gOhwi16o) return "gOhwi16o"; + if (v == mkldnn_gOhwi4o) return "gOhwi4o"; + if (v == mkldnn_gOhwi8o) return "gOhwi8o"; + if (v == mkldnn_Goihw16g) return "Goihw16g"; + if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o"; + if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i"; + if (v == mkldnn_gOihw16o) return "gOihw16o"; + if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i"; + if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i"; + if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o"; + if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i"; + if (v == mkldnn_gOihw4o) return "gOihw4o"; + if (v == mkldnn_Goihw8g) return "Goihw8g"; + if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i"; + if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o"; + if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o"; + if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i"; + if (v == mkldnn_gOdhwi16o) return "gOdhwi16o"; + if (v == mkldnn_gOdhwi4o) return "gOdhwi4o"; + if (v == mkldnn_gOdhwi8o) return "gOdhwi8o"; + if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o"; + if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i"; + if (v == mkldnn_gOidhw16o) return "gOidhw16o"; + if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o"; + if (v == mkldnn_gOidhw4o) return "gOidhw4o"; + if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i"; + if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o"; + if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i"; + assert(!"unknown fmt_tag"); + return "unknown fmt_tag"; +} + +const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) { + if (v == mkldnn_prop_kind_undef) return "undef"; + if (v == mkldnn_forward_training) return "forward_training"; + if (v == mkldnn_forward_inference) return "forward_inference"; + if (v == mkldnn_forward_scoring) return "forward_scoring"; + if (v == mkldnn_forward) return "forward"; + if (v == mkldnn_backward) return "backward"; + if (v == mkldnn_backward_data) return "backward_data"; + if (v == mkldnn_backward_weights) return "backward_weights"; + if (v == mkldnn_backward_bias) return "backward_bias"; + assert(!"unknown prop_kind"); + return "unknown prop_kind"; +} + +const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) { + if (v == mkldnn_undefined_primitive) return "undef"; + if (v == mkldnn_reorder) return "reorder"; + if (v == mkldnn_shuffle) return "shuffle"; + if (v == mkldnn_concat) return "concat"; + if (v == mkldnn_sum) return "sum"; + if (v == mkldnn_convolution) return "convolution"; + if (v == mkldnn_deconvolution) return "deconvolution"; + if (v == mkldnn_eltwise) return "eltwise"; + if (v == mkldnn_softmax) return "softmax"; + if (v == mkldnn_pooling) return "pooling"; + if (v == mkldnn_lrn) return "lrn"; + if (v == mkldnn_batch_normalization) return "batch_normalization"; + if (v == mkldnn_inner_product) return "inner_product"; + if (v == mkldnn_rnn) return "rnn"; + assert(!"unknown prim_kind"); + return "unknown prim_kind"; +} + +const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) { + if (v == mkldnn_alg_kind_undef) return "undef"; + if (v == mkldnn_convolution_direct) return "convolution_direct"; + if (v == mkldnn_convolution_winograd) return "convolution_winograd"; + if (v == mkldnn_convolution_auto) return "convolution_auto"; + if (v == mkldnn_deconvolution_direct) return "deconvolution_direct"; + if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd"; + if (v == mkldnn_eltwise_relu) return "eltwise_relu"; + if (v == mkldnn_eltwise_tanh) return "eltwise_tanh"; + if (v == mkldnn_eltwise_elu) return "eltwise_elu"; + if (v == mkldnn_eltwise_square) return "eltwise_square"; + if (v == mkldnn_eltwise_abs) return "eltwise_abs"; + if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt"; + if (v == mkldnn_eltwise_linear) return "eltwise_linear"; + if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu"; + if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu"; + if (v == mkldnn_eltwise_logistic) return "eltwise_logistic"; + if (v == mkldnn_pooling_max) return "pooling_max"; + if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding"; + if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding"; + if (v == mkldnn_pooling_avg) return "pooling_avg"; + if (v == mkldnn_lrn_across_channels) return "lrn_across_channels"; + if (v == mkldnn_lrn_within_channel) return "lrn_within_channel"; + if (v == mkldnn_vanilla_rnn) return "vanilla_rnn"; + if (v == mkldnn_vanilla_lstm) return "vanilla_lstm"; + if (v == mkldnn_vanilla_gru) return "vanilla_gru"; + if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset"; + assert(!"unknown alg_kind"); + return "unknown alg_kind"; +} + +const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) { + if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right"; + if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left"; + if (v == mkldnn_bidirectional_concat) return "bidirectional_concat"; + if (v == mkldnn_bidirectional_sum) return "bidirectional_sum"; + if (v == mkldnn_unidirectional) return "unidirectional"; + assert(!"unknown rnn_direction"); + return "unknown rnn_direction"; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp new file mode 100644 index 0000000000..7e5789e2c3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_HPP +#define MKLDNN_THREAD_HPP + +#include "utils.hpp" +#include "z_magic.hpp" + +#define MKLDNN_THR_SEQ 0 +#define MKLDNN_THR_OMP 1 +#define MKLDNN_THR_TBB 2 + +/* Ideally this condition below should never happen (if the library is built + * using regular cmake). For the 3rd-party projects that build the library + * from the sources on their own try to guess the right threading... */ +#if !defined(MKLDNN_THR) +# define MKLDNN_THR MKLDNN_THR_TBB +#endif + +#if MKLDNN_THR == MKLDNN_THR_SEQ +#define MKLDNN_THR_SYNC 1 +inline int mkldnn_get_max_threads() { return 1; } +inline int mkldnn_get_num_threads() { return 1; } +inline int mkldnn_get_thread_num() { return 0; } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() {} + +#define PRAGMA_OMP(...) + +#elif MKLDNN_THR == MKLDNN_THR_OMP +#include <omp.h> +#define MKLDNN_THR_SYNC 1 + +inline int mkldnn_get_max_threads() { return omp_get_max_threads(); } +inline int mkldnn_get_num_threads() { return omp_get_num_threads(); } +inline int mkldnn_get_thread_num() { return omp_get_thread_num(); } +inline int mkldnn_in_parallel() { return omp_in_parallel(); } +inline void mkldnn_thr_barrier() { +# pragma omp barrier +} + +#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__)) + +#elif MKLDNN_THR == MKLDNN_THR_TBB +#include "tbb/task_arena.h" +#include "tbb/parallel_for.h" +#define MKLDNN_THR_SYNC 0 + +inline int mkldnn_get_max_threads() +{ return tbb::this_task_arena::max_concurrency(); } +inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); } +inline int mkldnn_get_thread_num() +{ return tbb::this_task_arena::current_thread_index(); } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); } + +#define PRAGMA_OMP(...) + +#endif + +/* MSVC still supports omp 2.0 only */ +#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) +# define collapse(x) +# define PRAGMA_OMP_SIMD(...) +#else +# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__)) +#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER) + +namespace mkldnn { +namespace impl { + +inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; } + +template <typename T, typename U> +inline void balance211(T n, U team, U tid, T &n_start, T &n_end) { + T n_min = 1; + T &n_my = n_end; + if (team <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else if (n_min == 1) { + // team = T1 + T2 + // n = T1*n1 + T2*n2 (n1 - n2 = 1) + T n1 = utils::div_up(n, (T)team); + T n2 = n1 - 1; + T T1 = n - n2 * (T)team; + n_my = (T)tid < T1 ? n1 : n2; + n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; + } + + n_end += n_start; +} + +} // namespace impl +} // namespace mkldnn + +#include "mkldnn_thread_parallel_nd.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp new file mode 100644 index 0000000000..50f9b29622 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP +#define MKLDNN_THREAD_PARALLEL_ND_HPP + +/* This header must be included by mkldnn_thread.hpp only */ + +/* Functions: + * - parallel(nthr, f) - executes f in parallel using at most + * nthr threads. If nthr equals 0 + * mkldnn_get_max_threads() threads is + * used + * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already + * created threads + * - parallel_nd(dims..., f) - creates a parallel section and then + * calls for_nd + * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then + * calls for_nd (mostly for convenience) + */ + +namespace mkldnn { +namespace impl { + +/* general parallelization */ +template <typename F> +void parallel(int nthr, F f) { + if (nthr == 0) nthr = mkldnn_get_max_threads(); +#if MKLDNN_THR == MKLDNN_THR_SEQ + assert(nthr == 1); + f(0, 1); +#elif MKLDNN_THR == MKLDNN_THR_OMP + if (nthr == 1) { f(0, 1); return; } +# pragma omp parallel num_threads(nthr) + f(mkldnn_get_thread_num(), mkldnn_get_num_threads()); +#elif MKLDNN_THR == MKLDNN_THR_TBB + if (nthr == 1) { f(0, 1); return; } + tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner()); +#endif +} + +/* for_nd section */ + +template <typename T0, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, F f) { + T0 start{0}, end{0}; + balance211(D0, nthr, ithr, start, end); + for (T0 d0 = start; d0 < end; ++d0) f(d0); +} + +template <typename T0, typename T1, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } +} + +template <typename T0, typename T1, typename T2, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } +} + +template <typename T0, typename T1, typename T2, typename T3, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename T5, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } +} + +// Skip a lambda function in the parameter pack. +template <typename T> +constexpr size_t get_work_amount(const T &v) { return 1; } +template <typename T, typename ...Args> +constexpr size_t get_work_amount(const T &v, Args &&...args) +{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); } + +/* parallel_nd and parallel_nd_in_omp section */ + +#if MKLDNN_THR != MKLDNN_THR_TBB +template <typename ...Args> +void parallel_nd(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward<Args>(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1; +# pragma omp parallel if (do_parallel) + { + const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads(); + const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num(); + for_nd(ithr, nthr, utils::forward<Args>(args)...); + } +#endif +} +#else // MKLDNN_THR != MKLDNN_THR_TBB + +// gcc 4.8 has a bug with passing parameter pack to lambdas. +// So have to explicitly instantiate all the cases. + +template <typename T0, typename F> +void parallel_nd(const T0 &D0, F f) { + const size_t work_amount = (size_t)D0; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(T0(iwork)); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename F> +void parallel_nd(const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename T2, typename F> +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename T2, typename T3, typename F> +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename F> +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename T5, typename F> +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } + }, tbb::static_partitioner()); +} +#endif + +template <typename ...Args> +void parallel_nd_in_omp(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward<Args>(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(), + utils::forward<Args>(args)...); +#elif MKLDNN_THR == MKLDNN_THR_TBB + assert(!"unsupported parallel_nd_in_omp()"); +#endif +} + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp new file mode 100644 index 0000000000..aa671a0b6e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp @@ -0,0 +1,77 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_TRAITS_HPP +#define MKLDNN_TRAITS_HPP + +#include <assert.h> +#include <stdint.h> + +#include "mkldnn.h" +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +template <data_type_t> struct prec_traits {}; /* ::type -> float */ +template <typename> struct data_traits {}; /* ::data_type -> f32 */ +template <int> struct typesize_traits {}; /* ::data_type_size -> f32 */ +template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */ + +template <> struct prec_traits<data_type::f32> { typedef float type; }; +template <> struct prec_traits<data_type::s32> { typedef int32_t type; }; +template <> struct prec_traits<data_type::s8> { typedef int8_t type; }; +template <> struct prec_traits<data_type::u8> { typedef uint8_t type; }; + +template <> struct data_traits<float> +{ static constexpr data_type_t data_type = data_type::f32; }; +template <> struct data_traits<int32_t> +{ static constexpr data_type_t data_type = data_type::s32; }; +template <> struct data_traits<int8_t> +{ static constexpr data_type_t data_type = data_type::s8; }; +template <> struct data_traits<uint8_t> +{ static constexpr data_type_t data_type = data_type::u8; }; + +template <> struct typesize_traits<4> { typedef float type; }; +template <> struct typesize_traits<2> { typedef int16_t type; }; +template <> struct typesize_traits<1> { typedef uint8_t type; }; + +#define PKIND_TRAITS_INST(op) \ +template <> struct pkind_traits<primitive_kind::op> { \ + typedef CONCAT2(op, _desc_t) desc_type; \ + static constexpr query_t query_d = query::CONCAT2(op, _d); \ +} +PKIND_TRAITS_INST(convolution); +PKIND_TRAITS_INST(deconvolution); +PKIND_TRAITS_INST(shuffle); +PKIND_TRAITS_INST(eltwise); +PKIND_TRAITS_INST(softmax); +PKIND_TRAITS_INST(pooling); +PKIND_TRAITS_INST(lrn); +PKIND_TRAITS_INST(batch_normalization); +PKIND_TRAITS_INST(inner_product); +PKIND_TRAITS_INST(rnn); +#undef PKIND_TRAITS_INST + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp new file mode 100644 index 0000000000..f89ea999e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef NSTL_HPP +#define NSTL_HPP + +#include <stdint.h> +#include <limits.h> +#include <float.h> + +#include <vector> +#include <map> + +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +void *malloc(size_t size, int alignment); +void free(void *p); + +struct c_compatible { + enum { default_alignment = 64 }; + static void *operator new(size_t sz) { + return malloc(sz, default_alignment); + } + static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; } + static void *operator new[](size_t sz) { + return malloc(sz, default_alignment); + } + static void operator delete(void *p) { free(p); } + static void operator delete[](void *p) { free(p); } +}; + +namespace nstl { + +template<typename T> +inline const T abs(const T& a) { + return a >= 0 ? a : -a; +} + +template<typename T> +inline const T& max(const T& a, const T& b) { + return a > b ? a : b; +} + +template<typename T> +inline const T& min(const T& a, const T& b) { + return a < b ? a : b; +} + +template<typename T> void swap(T& t1, T& t2) { + T tmp(t1); + t1 = t2; + t2 = tmp; +} + +// Rationale: MKL-DNN needs numeric limits implementation that does not +// generate dependencies on C++ run-time libraries. + +template<typename T> struct numeric_limits; + +template<> struct numeric_limits<float> { + static constexpr float lowest() { return -FLT_MAX; } + static constexpr float max() { return FLT_MAX; } +}; + +template<> struct numeric_limits<int32_t> { + static constexpr int lowest() { return INT32_MIN; } + static constexpr int max() { return INT32_MAX; } +}; + +template<> struct numeric_limits<int16_t> { + static constexpr int16_t lowest() { return INT16_MIN; } + static constexpr int16_t max() { return INT16_MAX; } +}; + +template<> struct numeric_limits<int8_t> { + static constexpr int8_t lowest() { return INT8_MIN; } + static constexpr int8_t max() { return INT8_MAX; } +}; + +template<> struct numeric_limits<uint8_t> { + static constexpr uint8_t lowest() { return 0; } + static constexpr uint8_t max() { return UINT8_MAX; } +}; + +template<typename T> struct is_integral +{ static constexpr bool value = false; }; +template<> struct is_integral<int32_t> { static constexpr bool value = true; }; +template<> struct is_integral<int16_t> { static constexpr bool value = true; }; +template<> struct is_integral<int8_t> { static constexpr bool value = true; }; +template<> struct is_integral<uint8_t> { static constexpr bool value = true; }; + +template <typename T, typename U> struct is_same +{ static constexpr bool value = false; }; +template <typename T> struct is_same<T, T> +{ static constexpr bool value = true; }; + +// Rationale: MKL-DNN needs container implementations that do not generate +// dependencies on C++ run-time libraries. +// +// Implementation philosophy: caller is responsible to check if the operation +// is valid. The only functions that have to return status are those that +// depend on memory allocation or similar operations. +// +// This means that e.g. an operator [] does not have to check for boundaries. +// The caller should have checked the boundaries. If it did not we crash and +// burn: this is a bug in MKL-DNN and throwing an exception would not have been +// recoverable. +// +// On the other hand, insert() or resize() or a similar operation needs to +// return a status because the outcome depends on factors external to the +// caller. The situation is probably also not recoverable also, but MKL-DNN +// needs to be nice and report "out of memory" to the users. + +enum nstl_status_t { + success = 0, + out_of_memory +}; + +template <typename T> class vector: public c_compatible { +private: + std::vector<T> _impl; +public: + typedef typename std::vector<T>::iterator iterator; + typedef typename std::vector<T>::const_iterator const_iterator; + typedef typename std::vector<T>::size_type size_type; + vector() {} + vector(size_type n): _impl(n) {} + vector(size_type n, const T &value): _impl(n, value) {} + template <typename input_iterator> + vector(input_iterator first, input_iterator last): _impl(first, last) {} + ~vector() {} + size_type size() const { return _impl.size(); } + T& operator[] (size_type i) { return _impl[i]; } + const T& operator[] (size_type i) const { return _impl[i]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template <typename input_iterator> + nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end) + { + _impl.insert(pos, begin, end); + return success; + } + void clear() { _impl.clear(); } + void push_back(const T& t) { _impl.push_back(t); } + void resize(size_type count) { _impl.resize(count); } + void reserve(size_type count) { _impl.reserve(count); } +}; + +template <typename Key, typename T> class map: public c_compatible { +private: + std::map<Key, T> _impl; +public: + typedef typename std::map<Key, T>::iterator iterator; + typedef typename std::map<Key, T>::const_iterator const_iterator; + typedef typename std::map<Key, T>::size_type size_type; + map() {} + ~map() {} + size_type size() const { return _impl.size(); } + T& operator[](const Key &k) { return _impl[k]; } + const T& operator[](const Key &k) const { return _impl[k]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template <typename input_iterator> + void clear() { _impl.clear(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp new file mode 100644 index 0000000000..be96e654ff --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp @@ -0,0 +1,114 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t pooling_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l) + && one_of(alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto pd = pooling_desc_t(); + pd.primitive_kind = primitive_kind::pooling; + pd.prop_kind = prop_kind; + pd.alg_kind = alg_kind; + pd.src_desc.ndims = src_desc->ndims; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + pd.diff_src_desc = pd.src_desc = zero_md(); + pd.diff_dst_desc = pd.dst_desc = zero_md(); + + (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc; + (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(pd.strides, strides, sp_dims); + utils::array_copy(pd.kernel, kernel, sp_dims); + utils::array_copy(pd.padding[0], padding_l, sp_dims); + utils::array_copy(pd.padding[1], padding_r, sp_dims); + + pd.padding_kind = padding_kind; + if (one_of(alg_kind, pooling_max, pooling_avg_include_padding, + pooling_avg_exclude_padding)) { + pd.accum_data_type = types::default_accum_data_type( + src_desc->data_type, dst_desc->data_type); + } else { + pd.accum_data_type = dst_desc->data_type; + } + + bool consistency = true + && utils::one_of(src_desc->ndims, 4, 5) + && utils::one_of(dst_desc->ndims, 4, 5) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == dst_desc->dims[1]; + for (int i = 2; i < src_desc->ndims; ++i) + consistency = consistency && ( + (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2] + + padding_r[i - 2]) / strides[i - 2] + 1 + == dst_desc->dims[i]); + if (!consistency) return invalid_arguments; + + *pool_desc = pd; + return success; +} +} + +status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc, + dst_desc, strides, kernel, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_src_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t kernel, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind, + diff_src_desc, diff_dst_desc, strides, kernel, padding_l, + padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp new file mode 100644 index 0000000000..4c9c009412 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef POOLING_PD_HPP +#define POOLING_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct pooling_fwd_pd_t; + +struct pooling_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::pooling; + + pooling_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , ws_md_() + {} + + const pooling_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::pooling_d: + *(const pooling_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common pooling aux functions */ + + dim_t MB() const { return src_desc().dims[0]; } + dim_t C() const { return src_desc().dims[1]; } + + dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; } + dim_t IW() const { return src_desc().dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; } + dim_t OW() const { return dst_desc().dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; } + dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; } + dim_t KW() const { return desc_.kernel[ndims() - 3]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return src_desc().ndims; } + bool is_3d() const { return ndims() == 5; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(src_desc()).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + pooling_desc_t desc_; + const pooling_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t ws_md_; + + void init_default_ws() { + ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md(); + ws_md_.data_type = indices_data_type(); + } + + data_type_t indices_data_type() const { + /* the simplest way to express 256... */ + const int u8_max = nstl::numeric_limits< + typename prec_traits<data_type::u8>::type>::max(); + return utils::array_product(desc()->kernel, ndims()) <= u8_max + ? data_type::u8 : data_type::s32; + } + +private: + const memory_desc_t &src_desc() const + { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; } + const memory_desc_t &dst_desc() const + { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; } +}; + +struct pooling_fwd_pd_t: public pooling_pd_t { + typedef pooling_fwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_fwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } + +protected: + memory_desc_t src_md_; + memory_desc_t dst_md_; + + virtual status_t set_default_params() { + if (dst_md()->format_kind != format_kind::any) + return status::success; + + if (src_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(dst_md_, + src_md_.format_desc.blocking); + } +}; + +struct pooling_bwd_pd_t: public pooling_pd_t { + typedef pooling_bwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_bwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 1 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t diff_dst_md_; + + virtual status_t set_default_params() { + if (diff_src_md()->format_kind != format_kind::any) + return status::success; + + if (diff_dst_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(diff_src_md_, + diff_dst_md_.format_desc.blocking); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp new file mode 100644 index 0000000000..fdf6522f62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "primitive.hpp" +#include "type_helpers.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::primitive_kind; + +namespace { +// XXX: this is a huge hammer. This disables all and any msan checks on +// primitives outputs. +// +// A proper approach would be an implementation-specific unpoisoning. +void unpoison_outputs(const exec_args_t &args) { + for(const auto &arg: args) { + if (arg.second.is_const) continue; + auto *mem = arg.second.mem; + void *p; + mem->get_data_handle(&p); + size_t s = memory_desc_wrapper(*mem->md()).size(); + msan_unpoison(p, s); + } +} +} + +status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) { + if (primitive_desc) delete primitive_desc; + return success; +} + +status_t mkldnn_primitive_create(primitive_t **primitive, + const primitive_desc_t *primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return primitive_desc->create_primitive(primitive); +} + +status_t mkldnn_primitive_execute(const primitive_t *primitive, + stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) { + bool ok = true + && !utils::any_null(primitive, stream) + && primitive->engine() == stream->engine() + && IMPLICATION(nargs > 0, c_args != nullptr); + if (!ok) return invalid_arguments; + + exec_args_t args; + status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args); + if (status != status::success) return status; + + exec_ctx_t ctx(stream, std::move(args)); + + if (mkldnn_verbose()->level) { + double ms = get_msec(); + status = primitive->execute(ctx); + ms = get_msec() - ms; + printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms); + fflush(0); + } else { + status = primitive->execute(ctx); + } + + if (msan_enabled) unpoison_outputs(ctx.args()); + + return status; +} + +status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive, + const primitive_desc_t **primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return safe_ptr_assign<const primitive_desc_t>(*primitive_desc, + primitive->pd()); +} + +status_t mkldnn_primitive_destroy(primitive_t *primitive) { + if (primitive != nullptr) + delete primitive; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp new file mode 100644 index 0000000000..3b506d6d1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp @@ -0,0 +1,76 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_HPP +#define PRIMITIVE_HPP + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "primitive_exec_types.hpp" + +/** \brief A pure virtual primitive class + * + * Primitive contains links to its inputs & outputs, though it does not track + * their readiness on execution step. + * + * @remark @b Rational. + * Dependencies are essential through-out the whole MKL-DNN library, so it + * makes sense to include them on the very low level. On the other hand, + * tracking them should be a task for corresponding essence, like scheduler, + * stream or whatever. Primitive itself should know nothing about the + * environment it is running in. + * + * @note + * To make user experience better we should provide API which allows + * achieving the best (or good enough) performance when creating primitives + * in natural order: i.e. from bottom to top for forward pass and from top to + * bottom for backward pass. Please consider restriction [1] in Level 0. + */ +struct mkldnn_primitive: public mkldnn::impl::c_compatible { + mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd) + : pd_(pd->clone()) {} + virtual ~mkldnn_primitive() { delete pd_; } + + /** returns primitive's engine */ + mkldnn::impl::engine_t *engine() const { return pd_->engine(); } + /** returns primitive's inputs */ + const mkldnn::impl::primitive_desc_t *pd() const { return pd_; } + /** returns primitive's kind */ + mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); } + + /** executes primitive with execution context @p ctx */ + virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx) + const = 0; + +protected: + const mkldnn::impl::primitive_desc_t *pd_; + +private: + mkldnn_primitive() = delete; + mkldnn_primitive(const mkldnn_primitive &) = delete; + mkldnn_primitive(mkldnn_primitive &&) = delete; + mkldnn_primitive &operator=(const mkldnn_primitive &) = delete; + mkldnn_primitive &operator=(mkldnn_primitive &&) = delete; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp new file mode 100644 index 0000000000..9fd638842c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +namespace mkldnn { +namespace impl { + +status_t scales_t::set(dim_t count, int mask, const float *scales) { + cleanup(); + + count_ = count; + mask_ = mask; + + if (count_ == 1) { + scales_ = scales_buf_; + utils::array_set(scales_, scales[0], scales_buf_size); + } else { + scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64); + if (scales_ == nullptr) + return status::out_of_memory; + + for (dim_t c = 0; c < count_; ++c) + scales_[c] = scales[c]; + } + + return status::success; +} + +} +} + +status_t post_ops_t::append_sum(float scale) { + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::sum; + entry_[len_].sum.scale = scale; + + len_++; + + return success; +} + +status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha, + float beta) { + using namespace mkldnn::impl::alg_kind; + bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic); + if (!known_alg) + return invalid_arguments; + + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::eltwise; + entry_[len_].eltwise.scale = scale; + entry_[len_].eltwise.alg = alg; + entry_[len_].eltwise.alpha = alpha; + entry_[len_].eltwise.beta = beta; + + len_++; + + return success; +} + +status_t primitive_attr_t::set_scratchpad_mode( + scratchpad_mode_t scratchpad_mode) { + using namespace mkldnn::impl::scratchpad_mode; + + const bool ok = one_of(scratchpad_mode, library, user); + if (!ok) + return invalid_arguments; + + scratchpad_mode_ = scratchpad_mode; + return success; +} + +status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) { + this->post_ops_ = post_ops; + return success; +} + +/* Public C API */ + +status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) { + if (attr == nullptr) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_primitive_attr>(*attr, + new mkldnn_primitive_attr); +} + +status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr, + const primitive_attr_t *existing_attr) { + if (any_null(attr, existing_attr)) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_primitive_attr>(*attr, + existing_attr->clone()); +} + +status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) { + if (attr) + delete attr; + + return success; +} + +status_t mkldnn_primitive_attr_get_scratchpad_mode( + const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) { + if (any_null(attr, scratchpad_mode)) + return invalid_arguments; + + *scratchpad_mode = attr->scratchpad_mode_; + + return success; +} + +status_t mkldnn_primitive_attr_set_scratchpad_mode( + primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) { + if (any_null(attr)) + return invalid_arguments; + + return attr->set_scratchpad_mode(scratchpad_mode); +} + +status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr, + dim_t *count, int *mask, const float **scales) { + if (any_null(attr, count, mask, scales)) + return invalid_arguments; + + *count = attr->output_scales_.count_; + *mask = attr->output_scales_.mask_; + *scales = attr->output_scales_.scales_; + + return success; +} + +status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr, + dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->output_scales_.set(count, mask, scales); +} + +status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr, + const post_ops_t **post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + *post_ops = &attr->post_ops_; + return success; +} + +status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr, + const post_ops_t *post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + return attr->set_post_ops(*post_ops); +} + +status_t mkldnn_post_ops_create(post_ops_t **post_ops) { + if (post_ops == nullptr) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops); +} + +status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) { + if (post_ops) + delete post_ops; + + return success; +} + +int mkldnn_post_ops_len(const post_ops_t *post_ops) { + if (post_ops) + return post_ops->len_; + + return 0; +} + +primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops, + int index) { + bool ok = post_ops && 0 <= index && index < post_ops->len_; + if (!ok) + return primitive_kind::undefined; + + return post_ops->entry_[index].kind; +} + +status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_sum(scale); +} + +namespace { +bool simple_get_params_check(const post_ops_t *post_ops, int index, + primitive_kind_t kind) { + bool ok = true + && post_ops != nullptr + && 0 <= index + && index < post_ops->len_ + && post_ops->entry_[index].kind == kind; + return ok; +} +} + +status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index, + float *scale) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::sum) + && !any_null(scale); + if (!ok) + return invalid_arguments; + + *scale = post_ops->entry_[index].sum.scale; + return success; +} + +status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale, + alg_kind_t kind, float alpha, float beta) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_eltwise(scale, kind, alpha, beta); +} + +status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops, + int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::eltwise) + && !any_null(scale, alpha, beta); + if (!ok) + return invalid_arguments; + + const auto &e = post_ops->entry_[index].eltwise; + *scale = e.scale; + *alg = e.alg; + *alpha = e.alpha; + *beta = e.beta; + + return success; +} + +status_t mkldnn_primitive_attr_set_rnn_data_qparams( + primitive_attr_t *attr, const float scale, const float shift) { + if (attr == nullptr) + return invalid_arguments; + + return attr->rnn_data_qparams_.set(scale, shift); +} + +status_t mkldnn_primitive_attr_set_rnn_weights_qparams( + primitive_attr_t *attr, dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->rnn_weights_qparams_.set(count, mask, scales); +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp new file mode 100644 index 0000000000..e2130c7ab1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_ATTR_HPP +#define PRIMITIVE_ATTR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_data_qparams_t : public c_compatible { + rnn_data_qparams_t() : scale_(1.), shift_(0.) {} + bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); } + + status_t set(float scale, float shift) { + scale_ = scale; + shift_ = shift; + return status::success; + } + + float scale_; + float shift_; +}; + +struct scales_t: public c_compatible { + scales_t(): count_(1), mask_(0), scales_(scales_buf_) + { set(1.); } + + scales_t(const scales_t &rhs): scales_t() + { set(rhs.count_, rhs.mask_, rhs.scales_); } + + ~scales_t() { cleanup(); } + + scales_t &operator=(const scales_t &rhs) { + if (&rhs == this) + return *this; + status_t status = set(rhs.count_, rhs.mask_, rhs.scales_); + assert(status == status::success); + (void)status; + return *this; + } + + bool has_default_values() const { + for (dim_t c = 0; c < count_; ++c) { + if(scales_[c] != 1.) return false; + } + return true; + } + + status_t set(dim_t count, int mask, const float *scales); + status_t set(float single_scale) { return this->set(1, 0, &single_scale); } + + dim_t count_; + int mask_; + float *scales_; + +private: + enum { scales_buf_size = 16 }; + float scales_buf_[scales_buf_size]; + + void cleanup() { + if (scales_ != scales_buf_ && scales_ != nullptr) + impl::free(scales_); + + count_ = 1; + mask_ = 0; + scales_ = scales_buf_; + } +}; + +} +} + +struct mkldnn_post_ops: public mkldnn::impl::c_compatible { + struct entry_t { + struct eltwise_t { + mkldnn::impl::alg_kind_t alg; + float scale, alpha, beta; + }; + + mkldnn::impl::primitive_kind_t kind; + union { + struct { float scale; } sum; + eltwise_t eltwise; + }; + + bool is_eltwise(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::eltwise + && IMPLICATION(require_scale_one, eltwise.scale == 1.f); + } + + bool is_relu(bool require_scale_one = true, + bool require_nslope_zero = true) const { + using namespace mkldnn::impl; + return is_eltwise(require_scale_one) + && eltwise.alg == alg_kind::eltwise_relu + && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f); + } + + bool is_sum(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::sum + && IMPLICATION(require_scale_one, sum.scale == 1.f); + } + }; + + mkldnn_post_ops(): len_(0) {} + + mkldnn::impl::status_t append_sum(float scale); + mkldnn::impl::status_t append_eltwise(float scale, + mkldnn::impl::alg_kind_t alg, float alpha, float beta); + + int find(mkldnn::impl::primitive_kind_t kind, int start = 0, + int stop = -1) const { + if (stop == -1) stop = len_; + stop = mkldnn::impl::nstl::min(stop, len_); + for (int idx = start; idx < stop; ++idx) + if (entry_[idx].kind == kind) return idx; + return -1; + } + + bool has_default_values() const { return len_ == 0; } + + bool contain(mkldnn::impl::primitive_kind_t kind, int index) const + { return find(kind, index, index + 1) == index; } + + enum { capacity = 4 }; + + int len_; + entry_t entry_[capacity]; +}; + +struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible { + mkldnn_primitive_attr() + : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library) + {} + + mkldnn_primitive_attr *clone() const + { return new mkldnn_primitive_attr(*this); } + + /** Returns true if the attributes have default values. + * + * @note The scratchpad_mode_ is not take into account */ + bool has_default_values() const { + return true + && output_scales_.has_default_values() + && post_ops_.has_default_values() + && rnn_data_qparams_.has_default_values() + && rnn_weights_qparams_.has_default_values(); + } + + mkldnn::impl::status_t set_scratchpad_mode( + mkldnn::impl::scratchpad_mode_t scratchpad_mode); + mkldnn::impl::status_t set_post_ops( + const mkldnn::impl::post_ops_t &post_ops); + + mkldnn::impl::scratchpad_mode_t scratchpad_mode_; + mkldnn::impl::scales_t output_scales_; + mkldnn::impl::post_ops_t post_ops_; + mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_; + mkldnn::impl::scales_t rnn_weights_qparams_; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp new file mode 100644 index 0000000000..723c41e05a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t primitive_desc_t::query(query_t what, int idx, void *result) const { + auto safe_ret_md = [&](const memory_desc_t *_) { + if (_ == nullptr) return not_required; + *(const memory_desc_t **)result = _; + return success; + }; + + switch (what) { + case query::engine: *(engine_t**)result = engine(); break; + case query::primitive_kind: *(primitive_kind_t*)result = kind(); break; + + case query::scratchpad_engine: + *(engine_t**)result = scratchpad_engine(); break; + + case query::memory_consumption_s64: + *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break; + + case query::op_d: + if (idx != 0 || op_desc() == nullptr) return invalid_arguments; + *(const_c_op_desc_t *)result + = static_cast<const_c_op_desc_t>(op_desc()); break; + + case query::src_md: return safe_ret_md(src_md(idx)); + case query::diff_src_md: return safe_ret_md(diff_src_md(idx)); + case query::dst_md: return safe_ret_md(dst_md(idx)); + case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx)); + case query::weights_md: return safe_ret_md(weights_md(idx)); + case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx)); + case query::workspace_md: + if (idx != 0) return status::invalid_arguments; + return safe_ret_md(workspace_md(idx)); + case query::scratchpad_md: + if (idx != 0) return status::invalid_arguments; + return safe_ret_md(scratchpad_md(idx)); + + case query::num_of_inputs_s32: *(int*)result = n_inputs(); break; + case query::num_of_outputs_s32: *(int*)result = n_outputs(); break; + + case query::impl_info_str: *(const char **)result = name(); break; + + default: return unimplemented; + } + return success; +} + +status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc, + const primitive_attr_t **attr) { + if (utils::any_null(primitive_desc, attr)) + return invalid_arguments; + + *attr = primitive_desc->attr(); + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp new file mode 100644 index 0000000000..536dcfa1d0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp @@ -0,0 +1,174 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_DESC_HPP +#define PRIMITIVE_DESC_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "primitive_attr.hpp" +#include "verbose.hpp" + +struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { + using md_t = mkldnn::impl::memory_desc_t; + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), kind_(kind) { info_[0] = '\0'; } + + virtual mkldnn_primitive_desc *clone() const = 0; + virtual ~mkldnn_primitive_desc() {} + + const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } + mkldnn::impl::engine_t *engine() const { return engine_; } + mkldnn::impl::primitive_kind_t kind() const { return kind_; } + + virtual void init_info() {} + const char *info() const { return info_; } + + mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() + { return scratchpad_registry_; } + const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const + { return scratchpad_registry_; } + virtual mkldnn::impl::engine_t *scratchpad_engine() const + { return engine_; } + + virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } + + enum class arg_usage_t { unused, input, output }; + virtual arg_usage_t arg_usage( + mkldnn::impl::primitive_arg_index_t arg) const { + using mkldnn::impl::types::is_zero_md; + if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) + return arg_usage_t::output; + return arg_usage_t::unused; + } + +# define DECLARE_MD_STUB(stub) \ + virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ + { return nullptr; } + + DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); + DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); + DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); + DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); + DECLARE_MD_STUB(workspace_md); +# undef DECLARE_MD_STUB + + const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { + return idx == 0 ? &scratchpad_md_ : nullptr; + } + + virtual void init_scratchpad_md() { + auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); + mkldnn::impl::dims_t dims = { size }; + mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, + mkldnn::impl::data_type::u8, mkldnn_x); + } + + /** returns the scratchpad size for the given scratchpad mode. */ + mkldnn::impl::dim_t scratchpad_size( + mkldnn::impl::scratchpad_mode_t mode) const { + if (mode != attr_.scratchpad_mode_) return 0; + return scratchpad_registry().size(); + } + + virtual int n_inputs() const { return 0; } + virtual int n_outputs() const { return 0; } + + virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, + void *result) const; + + virtual mkldnn::impl::status_t create_primitive( + mkldnn::impl::primitive_t **primitive) const = 0; + + virtual const char *name() const { return "mkldnn_primitive_desc"; } + + /* static magic */ + + template<typename pd_t> + static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, + const mkldnn::impl::op_desc_t *adesc, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_desc_t *hint_fwd) { + using namespace mkldnn::impl; + using namespace mkldnn::impl::status; + using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type; + if (adesc->kind != pd_t::base_pkind) return invalid_arguments; + assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); + auto hint = + reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd); + auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->init_info(); + _pd->init_scratchpad_md(); + *pd = _pd; + return success; + } + +protected: + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_attr_t attr_; + mkldnn::impl::primitive_kind_t kind_; + + mkldnn::impl::memory_desc_t scratchpad_md_; + + char info_[MKLDNN_VERBOSE_BUF_LEN]; + + mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; + +protected: + /** compares ws between fwd_pd and this (make sense to use for bwd_pd) + * Expectation: this already set workspace, and this workspace should + * exactly match the one from fwd_pd */ + bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { + using namespace mkldnn::impl; + if (!workspace_md()) return true; // the impl lives fine w/o workspace + return fwd_pd && fwd_pd->workspace_md() + && *fwd_pd->workspace_md() == *workspace_md(); + } +}; + +#define DECLARE_COMMON_PD_t(impl_name, ...) \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual const char *name() const override { return impl_name; } +#define DECLARE_COMMON_PD_T(impl_name, ...) \ + DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__) + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp new file mode 100644 index 0000000000..43e5a31ef3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "memory.hpp" +#include "primitive.hpp" +#include "primitive_exec_types.hpp" + +namespace mkldnn { +namespace impl { + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args) { + using namespace status; + + if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments; + + int n_inputs = 0; + int n_outputs = 0; + + for (int i = 0; i < nargs; ++i) { + primitive_arg_index_t arg = c_args[i].arg; + auto *mem = c_args[i].memory; + + switch (pd->arg_usage(arg)) { + case primitive_desc_t::arg_usage_t::input: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, true}; + n_inputs++; + break; + case primitive_desc_t::arg_usage_t::output: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, false}; + n_outputs++; + break; + case primitive_desc_t::arg_usage_t::unused: + break; + } + } + + bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md()); + + if (n_inputs != pd->n_inputs()) return invalid_arguments; + if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0)) + return invalid_arguments; + + return success; +} + +const void *exec_ctx_t::input(primitive_arg_index_t arg) const { + if (args_.count(arg) != 1) return nullptr; + const auto ma = args_.at(arg); + assert(ma.is_const); + void *ptr; + status_t status = ma.mem->get_data_handle(&ptr); + assert(status == status::success); MAYBE_UNUSED(status); + return ptr; +} + +void *exec_ctx_t::output(primitive_arg_index_t arg) const { + if (args_.count(arg) != 1) return nullptr; + const auto ma = args_.at(arg); + assert(!ma.is_const); + void *ptr; + status_t status = ma.mem->get_data_handle(&ptr); + assert(status == status::success); MAYBE_UNUSED(status); + return ptr; +} + +const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const { + assert(args_.count(arg) == 1); + const auto ma = args_.at(arg); + assert(!ma.is_const); + return ma.mem; +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp new file mode 100644 index 0000000000..0645891da7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_EXEC_TYPES_HPP +#define PRIMITIVE_EXEC_TYPES_HPP + +#include <unordered_map> + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "memory.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct memory_arg_t { + memory_t *mem; + bool is_const; +}; + +using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>; + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args); + +/** Primitive execution context (helps passing stream, memories, and events. */ +struct exec_ctx_t { + exec_ctx_t(const exec_ctx_t &) = default; + exec_ctx_t(exec_ctx_t &&) = default; + + exec_ctx_t(stream_t *stream): stream_(stream) {} + exec_ctx_t(stream_t *stream, exec_args_t &&args) + : stream_(stream) + , args_(std::move(args)) {} + + stream_t *stream() const { return stream_; } + const exec_args_t &args() const { return args_; } + + /* tentative solution... TODO: replace with functions return memory_t */ + const void *input(primitive_arg_index_t arg) const; + void *output(primitive_arg_index_t arg) const; + + const memory_t *memory(primitive_arg_index_t arg) const; + +private: + stream_t *stream_; + exec_args_t args_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp new file mode 100644 index 0000000000..5a1cd7d379 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_iterator_create( + primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc, + const primitive_attr_t *attr, engine_t *engine, + const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd); + if (it == nullptr) return out_of_memory; + + ++(*it); + if (*it == it->end()) { + delete it; + return unimplemented; + } + + *iterator = it; + return success; +} + +status_t mkldnn_primitive_desc_iterator_next( + primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return invalid_arguments; + ++(*iterator); + return *iterator == iterator->end() ? iterator_ends : success; +} + +primitive_desc_t *mkldnn_primitive_desc_iterator_fetch( + const primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return nullptr; + return *(*iterator); +} + +status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc, + const primitive_desc_t *existing_primitive_desc) { + if (utils::any_null(primitive_desc, existing_primitive_desc)) + return invalid_arguments; + return safe_ptr_assign<primitive_desc_t>(*primitive_desc, + existing_primitive_desc->clone()); +} + +status_t mkldnn_primitive_desc_iterator_destroy( + primitive_desc_iterator_t *iterator) { + if (iterator != nullptr) + delete iterator; + return success; +} + +status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc, + const_c_op_desc_t c_op_desc, const primitive_attr_t *attr, + engine_t *engine, const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd); + ++it; + if (it == it.end()) return unimplemented; + + return safe_ptr_assign<primitive_desc_t>(*primitive_desc, *it); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp new file mode 100644 index 0000000000..4e88ab3aa5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#ifndef PRIMITIVE_ITERATOR_HPP +#define PRIMITIVE_ITERATOR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible { + using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc, + const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd) + : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc) + , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd) + , impl_list_(engine_->get_implementation_list()), last_idx_(0) + { + while (impl_list_[last_idx_] != nullptr) ++last_idx_; + } + ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; } + + bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return idx_ == rhs.idx_ && engine_ == rhs.engine_; } + bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return !operator==(rhs); } + + mkldnn::impl::primitive_desc_iterator_t end() const + { return mkldnn_primitive_desc_iterator(engine_, last_idx_); } + + mkldnn::impl::primitive_desc_iterator_t &operator++() { + if (pd_) { delete pd_; pd_ = nullptr; } + while (++idx_ != last_idx_) { + auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_, + hint_fwd_pd_); + if (s == mkldnn::impl::status::success) break; + } + return *this; + } + + mkldnn::impl::primitive_desc_t *operator*() const { + if (*this == end() || pd_ == nullptr) return nullptr; + return pd_->clone(); + } + +protected: + int idx_; + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_desc_t *pd_; + const mkldnn::impl::op_desc_t *op_desc_; + const mkldnn::impl::primitive_attr_t attr_; + const mkldnn::impl::primitive_desc_t *hint_fwd_pd_; + const pd_create_f *impl_list_; + int last_idx_; + +private: + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx) + : idx_(last_idx), engine_(engine), pd_(nullptr) + , op_desc_(nullptr), hint_fwd_pd_(nullptr) + , impl_list_(nullptr), last_idx_(last_idx) {} +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp new file mode 100644 index 0000000000..835cd73581 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/query.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc, + query_t what, int index, void *result) { + if (any_null(primitive_desc, result)) + return invalid_arguments; + + return primitive_desc->query(what, index, result); +} + +const memory_desc_t *mkldnn_primitive_desc_query_md( + const primitive_desc_t *primitive_desc, query_t what, int index) { + const memory_desc_t *res_md = nullptr; + bool args_ok = true + && primitive_desc != nullptr + && (what & query::some_md) == query::some_md + && what != query::some_md + && mkldnn_primitive_desc_query(primitive_desc, + what, index, &res_md) == success; + return args_ok ? res_md : nullptr; +} + +int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc, + query_t what, int index) { + int res_s32; + bool args_ok = primitive_desc != nullptr + && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32) + && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32) + == success; + return args_ok ? res_s32 : 0; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp new file mode 100644 index 0000000000..d11f1a0361 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "reorder_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_reorder_primitive_desc_create( + primitive_desc_t **reorder_pd, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md, + const primitive_attr_t *attr) { + if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md)) + return invalid_arguments; + + auto s_ek = src_engine->kind(); + auto d_ek = dst_engine->kind(); + if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek))) + return invalid_arguments; + + auto r_pd = reinterpret_cast<reorder_pd_t **>(reorder_pd); + auto s_mdw = memory_desc_wrapper(*src_md); + auto d_mdw = memory_desc_wrapper(*dst_md); + + if (!s_mdw.consistent_with(d_mdw)) + return invalid_arguments; + + auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + for (auto r = e->get_reorder_implementation_list(); *r; ++r) { + if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md) + == success) { + (*r_pd)->init_info(); + (*r_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp new file mode 100644 index 0000000000..963cb0f58a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp @@ -0,0 +1,85 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REORDER_PD_HPP +#define REORDER_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct reorder_pd_t: public primitive_desc_t { + reorder_pd_t(engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) + : primitive_desc_t(engine, attr, primitive_kind::reorder) + , src_engine_(src_engine) + , dst_engine_(dst_engine) + , scratchpad_engine_(nullptr) + , src_md_(*src_md) + , dst_md_(*dst_md) + {} + + virtual const op_desc_t *op_desc() const override { return nullptr; } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_FROM) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_TO) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + float alpha() const { return attr()->output_scales_.scales_[0]; } + float beta() const { + const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); + return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale; + } + virtual mkldnn::impl::engine_t *scratchpad_engine() const override + { return scratchpad_engine_; } + +protected: + engine_t *src_engine_; + engine_t *dst_engine_; + engine_t *scratchpad_engine_; + + memory_desc_t src_md_; + memory_desc_t dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp new file mode 100644 index 0000000000..36967431a6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu/gemm/os_blas.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::utils; + +namespace { +memory_desc_t copy_maybe_null(const memory_desc_t *md) { + return md ? *md : zero_md(); +} + +rnn_desc_t zero_rnn_desc() { + auto rd = rnn_desc_t(); + rd.src_layer_desc = zero_md(); + rd.src_iter_desc = zero_md(); + rd.weights_layer_desc = zero_md(); + rd.weights_iter_desc = zero_md(); + rd.bias_desc = zero_md(); + rd.dst_layer_desc = zero_md(); + rd.dst_iter_desc = zero_md(); + rd.diff_src_layer_desc = zero_md(); + rd.diff_src_iter_desc = zero_md(); + rd.diff_weights_layer_desc = zero_md(); + rd.diff_weights_iter_desc = zero_md(); + rd.diff_bias_desc = zero_md(); + rd.diff_dst_layer_desc = zero_md(); + rd.diff_dst_iter_desc = zero_md(); + return rd; +} +} + +/* Public C Api */ + +status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, + unsigned int flags, float alpha, float clipping) { + using namespace mkldnn::impl::alg_kind; + + bool args_ok = true + && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru, + gru_linear_before_reset) + && IMPLICATION(cell_kind == vanilla_rnn, + one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic)); + if (!args_ok) + return invalid_arguments; + + auto rcd = mkldnn_rnn_cell_desc_t(); + + rcd.cell_kind = cell_kind; + rcd.activation_kind = act_f; + rcd.flags = flags; + rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; + rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0; + + *rnn_cell_desc = rcd; + + return success; +} + +int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 3; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3; + case mkldnn::impl::alg_kind::vanilla_lstm: return 4; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 1; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1; + case mkldnn::impl::alg_kind::vanilla_lstm: return 2; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc, + prop_kind_t prop_kind, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + using namespace data_type; + data_type_t src_layer_dt = src_layer_desc->data_type; + data_type_t dst_layer_dt = dst_layer_desc->data_type; + data_type_t weights_iter_dt = weights_iter_desc->data_type; + data_type_t weights_layer_dt = weights_layer_desc->data_type; + + bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt, + weights_layer_dt) + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + +#if USE_MKL_PACKED_GEMM + bool is_u8u8u8 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == u8) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == u8) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_f32u8f32 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_inference = prop_kind == prop_kind::forward_inference; + bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm; + + return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference)) + ? success + : unimplemented; +#else + return is_f32 ? success : unimplemented; +#endif +} + +status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc, + rnn_direction_t direction, int L, int D, int T, int N, int S, int G, + int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok; + + // * algorithm specific + args_ok = true + && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru, + DIC == SIC); + if (!args_ok) return invalid_arguments; + int extra_bias = + rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset; + + // * on num layers + args_ok = true + && L == weights_layer_desc->dims[0] + && L == weights_iter_desc->dims[0] + && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0]) + && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]) + && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]); + if (!args_ok) return invalid_arguments; + + // * on num directions + args_ok = true + && D == weights_layer_desc->dims[1] + && D == weights_iter_desc->dims[1] + && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1]) + && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1]) + && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]); + if (!args_ok) return invalid_arguments; + + // * on num iterations + args_ok = true + && T == src_layer_desc->dims[0] + && T == dst_layer_desc->dims[0]; + if (!args_ok) return invalid_arguments; + + // * on mb + args_ok = true + && N == src_layer_desc->dims[1] + && N == dst_layer_desc->dims[1] + && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]); + if (!args_ok) return invalid_arguments; + + // * on num gates + args_ok = true + && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc) + && G == weights_layer_desc->dims[3] + && G == weights_iter_desc->dims[3] + && IMPLICATION(!is_zero_md(bias_desc), + G + extra_bias == bias_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on num states + args_ok = true + && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc) + && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2]) + && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on slc + args_ok = true + && SLC == weights_layer_desc->dims[2] + && SLC == src_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on sic + args_ok = true + && SIC == weights_iter_desc->dims[2] + && IMPLICATION(!is_zero_md(src_iter_desc), + SIC == src_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * on dlc + int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1; + args_ok = true + && DLC == dlc_multiplier * DIC + && DLC == dst_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on dic + args_ok = true + && DIC == weights_layer_desc->dims[4] + && DIC == weights_iter_desc->dims[4] + && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), + DIC == dst_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * unrolling/fusion conditions + args_ok = true + && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC) + && IMPLICATION(T > 1, SIC == DIC); + if (!args_ok) return invalid_arguments; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok = true && rnn_cell_desc != nullptr + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc); + if (!args_ok) return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc)); + + CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind, + src_layer_desc, src_iter_desc, weights_layer_desc, + weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc)); + + // Create the descriptor + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + + *rnn_desc = rd; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, + const memory_desc_t *diff_src_layer_desc, + const memory_desc_t *diff_src_iter_desc, + const memory_desc_t *diff_weights_layer_desc, + const memory_desc_t *diff_weights_iter_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_layer_desc, + const memory_desc_t *diff_dst_iter_desc) { + bool args_ok = true + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc, diff_src_layer_desc, + diff_weights_layer_desc, diff_weights_iter_desc, + diff_dst_layer_desc); + if (!args_ok) + return invalid_arguments; + + auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { + return is_zero_md(a_md) == is_zero_md(b_md); + }; + + args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) + && xnor_md(dst_iter_desc, diff_dst_iter_desc) + && xnor_md(src_iter_desc, diff_src_iter_desc); + if (!args_ok) + return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc); + if (st != success) return st; + + st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc, + diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc, + diff_dst_layer_desc, diff_dst_iter_desc); + if (st != success) return st; + + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); + rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); + rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); + rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); + rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); + rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc); + rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); + + *rnn_desc = rd; + + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp new file mode 100644 index 0000000000..1ee2ba1114 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_PD_HPP +#define RNN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_fwd_pd_t; + +struct rnn_pd_t : public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::rnn; + + rnn_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , src_layer_md_(desc_.src_layer_desc) + , src_iter_md_(desc_.src_iter_desc) + , weights_layer_md_(desc_.weights_layer_desc) + , weights_iter_md_(desc_.weights_iter_desc) + , bias_md_(desc_.bias_desc) + , dst_layer_md_(desc_.dst_layer_desc) + , dst_iter_md_(desc_.dst_iter_desc) + , ws_md_() + {} + + const rnn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::rnn_d: *(const rnn_desc_t **)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &src_layer_md_; + if (index == 1 && with_src_iter()) return &src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_layer_md_; + if (index == 1) return &weights_iter_md_; + if (index == 2 && with_bias()) return &bias_md_; + return nullptr; + } + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &dst_layer_md_; + if (index == 1 && with_dst_iter()) return &dst_iter_md_; + return nullptr; + } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + /* common pooling aux functions */ + + bool is_training() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::backward); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + dim_t T() const { return desc_.src_layer_desc.dims[0]; } + dim_t MB() const { return desc_.src_layer_desc.dims[1]; } + + dim_t L() const { return desc_.weights_layer_desc.dims[0]; } + dim_t D() const { return desc_.weights_layer_desc.dims[1]; } + + dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; } + + dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; } + dim_t G() const { return desc_.weights_layer_desc.dims[3]; } + dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; } + + dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; } + + bool with_bias() const + { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); } + + bool with_src_iter() const + { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); } + + bool with_dst_iter() const + { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); } + + mkldnn::impl::alg_kind_t cell_kind() const + { return desc_.cell_desc.cell_kind; } + mkldnn::impl::alg_kind_t activation_kind() const + { return desc_.cell_desc.activation_kind; } + + bool is_lbr() const + { return cell_kind() == mkldnn_gru_linear_before_reset; } + + mkldnn_rnn_direction_t direction() const { return desc_.direction; } + +protected: + rnn_desc_t desc_; + const rnn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t src_layer_md_; + memory_desc_t src_iter_md_; + memory_desc_t weights_layer_md_; + memory_desc_t weights_iter_md_; + memory_desc_t bias_md_; + memory_desc_t dst_layer_md_; + memory_desc_t dst_iter_md_; + + memory_desc_t ws_md_; +}; + +struct rnn_fwd_pd_t: public rnn_pd_t { + typedef rnn_fwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_fwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC_LAYER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter()) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST_LAYER) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter()) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual int n_inputs() const override + { return 3 + with_bias() + with_src_iter(); } + virtual int n_outputs() const override + { return 1 + with_dst_iter() + is_training(); } +}; + +struct rnn_bwd_pd_t : public rnn_pd_t { + typedef rnn_bwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_bwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_layer_md_(desc_.diff_src_layer_desc) + , diff_src_iter_md_(desc_.diff_src_iter_desc) + , diff_weights_layer_md_(desc_.diff_weights_layer_desc) + , diff_weights_iter_md_(desc_.diff_weights_iter_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_layer_md_(desc_.diff_dst_layer_desc) + , diff_dst_iter_md_(desc_.diff_dst_iter_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER, + MKLDNN_ARG_DIFF_DST_LAYER)) + return arg_usage_t::input; + + if (with_src_iter()) { + if (arg == MKLDNN_ARG_SRC_ITER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC_ITER) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (with_bias()) { + if (arg == MKLDNN_ARG_BIAS) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_BIAS) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER) + && with_dst_iter()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_ITER)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override { + if (index == 0) return &diff_src_layer_md_; + if (index == 1 && with_src_iter()) return &diff_src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *diff_weights_md( + int index = 0) const override { + if (index == 0) return &diff_weights_layer_md_; + if (index == 1) return &diff_weights_iter_md_; + if (index == 2 && with_bias()) return &diff_bias_md_; + return nullptr; + } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override { + if (index == 0) return &diff_dst_layer_md_; + if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_; + return nullptr; + } + + virtual int n_inputs() const override + { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); } + virtual int n_outputs() const override + { return 3 + with_src_iter() + with_bias(); } + +protected: + memory_desc_t diff_src_layer_md_; + memory_desc_t diff_src_iter_md_; + memory_desc_t diff_weights_layer_md_; + memory_desc_t diff_weights_iter_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_layer_md_; + memory_desc_t diff_dst_iter_md_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp new file mode 100644 index 0000000000..6bc14fc72a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "scratchpad.hpp" + +namespace mkldnn { +namespace impl { + +/* Allocating memory buffers on a page boundary to reduce TLB/page misses */ +const size_t page_size = 2097152; + +/* + Implementation of the scratchpad_t interface that is compatible with + a concurrent execution +*/ +struct concurent_scratchpad_t : public scratchpad_t { + concurent_scratchpad_t(size_t size) { + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + + ~concurent_scratchpad_t() { + free(scratchpad_); + } + + virtual char *get() const { + return scratchpad_; + } + +private: + char *scratchpad_; + size_t size_; +}; + +/* + Implementation of the scratchpad_t interface that uses a global + scratchpad +*/ + +struct global_scratchpad_t : public scratchpad_t { + global_scratchpad_t(size_t size) { + if (size > size_) { + if (scratchpad_ != nullptr) free(scratchpad_); + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + reference_count_++; + } + + ~global_scratchpad_t() { + reference_count_--; + if (reference_count_ == 0) { + free(scratchpad_); + scratchpad_ = nullptr; + size_ = 0; + } + } + + virtual char *get() const { + return scratchpad_; + } + +private: + /* + Using thread-local here is unnecessary and even buggy! All threads + actually share the same scratchpad, which is created and queried only + on the main thread. If the scratchpad is queried on some thread other + than the one it was created on (e.g. the application calls the API from + multiple threads), thread-local causes a segfault because the scratchpad + is uninitialized on the current thread. + */ + /*thread_local*/ static char *scratchpad_; + /*thread_local*/ static size_t size_; + /*thread_local*/ static unsigned int reference_count_; +}; + +/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr; +/*thread_local*/ size_t global_scratchpad_t::size_ = 0; +/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0; + + +/* + Scratchpad creation routine +*/ +scratchpad_t *create_scratchpad(size_t size) { +#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC + return new global_scratchpad_t(size); +#else + return new concurent_scratchpad_t(size); +#endif +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp new file mode 100644 index 0000000000..f7a246bc99 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_SCRATCHPAD_HPP +#define COMMON_SCRATCHPAD_HPP + +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct scratchpad_t { + virtual ~scratchpad_t() {} + virtual char *get() const = 0; +}; + +scratchpad_t *create_scratchpad(size_t size); + +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp new file mode 100644 index 0000000000..e32e735224 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, int axis, dim_t group_size) { + bool args_ok = true + && !any_null(shuffle_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward, backward_data) + && axis >= 0 && axis < data_desc->ndims + && group_size > 0 && group_size <= data_desc->dims[axis]; + if (!args_ok) return invalid_arguments; + + auto sd = shuffle_desc_t(); + sd.primitive_kind = primitive_kind::shuffle; + sd.prop_kind = prop_kind; + sd.data_desc = *data_desc; + sd.axis = axis; + sd.group_size = group_size; + + bool consistency = true + && sd.data_desc.dims[axis] % sd.group_size == 0; + if (!consistency) return invalid_arguments; + + *shuffle_desc = sd; + return success; +} +} + +status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis, + dim_t group_size) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return shuffle_desc_init(shuffle_desc, prop_kind, data_desc, axis, + group_size); +} + +status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc, + const memory_desc_t *diff_data_desc, int axis, dim_t group_size) { + return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis, + group_size); +} + +// vim: et ts=5 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp new file mode 100644 index 0000000000..cc5553fe7f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SHUFFLE_PD_HPP +#define SHUFFLE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct shuffle_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::shuffle; + + typedef shuffle_pd_t base_class; + typedef shuffle_pd_t hint_class; + + shuffle_pd_t(engine_t *engine, + const shuffle_desc_t *adesc, + const primitive_attr_t *attr, + const shuffle_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const shuffle_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::shuffle_d: + *(const shuffle_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (is_fwd()) { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + } else { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + } + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + /* shuffle aux functions */ + + dim_t MB() const { return data_md()->dims[0]; } + dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; } + dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_md()->dims[ndims() - 1] : 1; } + + int ndims() const { return data_md()->ndims; } + + int axis() const { return desc_.axis; } + dim_t group_size() const { return desc_.group_size; } + dim_t axis_size() const { return data_md()->dims[axis()]; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + const memory_desc_t *data_md() const { return &data_md_; } + +protected: + shuffle_desc_t desc_; + const shuffle_pd_t *hint_fwd_pd_; + memory_desc_t data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp new file mode 100644 index 0000000000..82848e3d1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) { + bool args_ok = true + && !any_null(softmax_desc, data_desc) + && 0 <= softmax_axis + && softmax_axis < data_desc->ndims; + if (!args_ok) return invalid_arguments; + + auto sd = softmax_desc_t(); + sd.primitive_kind = primitive_kind::softmax; + sd.prop_kind = prop_kind; + + bool is_bwd = (sd.prop_kind == backward_data); + sd.data_desc = *data_desc; + sd.diff_desc = is_bwd ? *diff_desc : zero_md(); + sd.softmax_axis = softmax_axis; + + *softmax_desc = sd; + return success; +} +} + +status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + int softmax_axis) { + if (!one_of(prop_kind, forward_inference, forward_training)) + return invalid_arguments; + return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis); +} + +status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc, + const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, + int softmax_axis) { + return softmax_desc_init(softmax_desc, prop_kind::backward_data, + data_desc, diff_desc, softmax_axis); +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp new file mode 100644 index 0000000000..8a16ce901c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SOFTMAX_PD_HPP +#define SOFTMAX_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct softmax_fwd_pd_t; + +struct softmax_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::softmax; + + softmax_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const softmax_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::softmax_d: + *(const softmax_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common softmax aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + softmax_desc_t desc_; + const softmax_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct softmax_fwd_pd_t: public softmax_pd_t { + typedef softmax_fwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_fwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct softmax_bwd_pd_t: public softmax_pd_t { + typedef softmax_bwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_bwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp new file mode 100644 index 0000000000..00af8935c0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +/* API */ + +status_t mkldnn_stream_create(stream_t **stream, engine_t *engine, + unsigned flags) { + bool args_ok = true + && !utils::any_null(stream, engine) + && flags == stream_flags::default_flags; + if (!args_ok) + return invalid_arguments; + + return safe_ptr_assign<stream_t>(*stream, new stream_t(engine, flags)); +} + +status_t mkldnn_stream_destroy(stream_t *stream) { + delete stream; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp new file mode 100644 index 0000000000..f010e5f6ed --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp @@ -0,0 +1,44 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef STREAM_HPP +#define STREAM_HPP + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" + +struct mkldnn_stream: public mkldnn::impl::c_compatible { + mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags) + : engine_(engine), flags_(flags) {} + virtual ~mkldnn_stream() {} + + /** returns stream's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + + /** returns stream's kind */ + unsigned flags() const { return flags_; } + +protected: + mkldnn::impl::engine_t *engine_; + unsigned flags_; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp new file mode 100644 index 0000000000..365663c0f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "sum_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds, const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != dims[d]) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto s_pd = reinterpret_cast<sum_pd_t **>(sum_pd); + + for (auto s = engine->get_sum_implementation_list(); *s; ++s) { + if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) { + (*s_pd)->init_info(); + (*s_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp new file mode 100644 index 0000000000..80254667df --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SUM_PD_HPP +#define SUM_PD_HPP + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct sum_pd_t: public primitive_desc_t { + sum_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::sum) + , n_(n), dst_md_(*dst_md) + { + scales_.reserve(n_); + for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]); + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + const float *scales() const { return &scales_[0]; } + +protected: + int n_; + nstl::vector<float> scales_; + memory_desc_t dst_md_; + nstl::vector<memory_desc_t> src_mds_; + +protected: + /* inits dst_md_ in simple cases. The call may fail. */ + status_t init() { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(&src_mds_[i]); + if (!src_d.is_blocking_desc() || src_d.is_additional_buffer()) + return status::unimplemented; + } + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain, pick the format of the first input + */ + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (!src_d.is_plain() && src_d.is_blocking_desc()) { + return memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + } + } + + if (src_mds_[0].format_kind != format_kind::blocked) + return status::unimplemented; + + dst_md_ = src_mds_[0]; + + return status::success; + } +}; + +#define DECLARE_SUM_PD_t(impl_name, ...) \ + static status_t create(sum_pd_t **sum_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, const float *scales, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign<sum_pd_t>(*sum_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_SUM_PD_T(impl_name, ...) \ + DECLARE_SUM_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp new file mode 100644 index 0000000000..a408f45980 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TAG_TRAITS_HPP +#define TAG_TRAITS_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +enum class block_dim_t { + _, + _A, _B, + _AB, _BC, +}; + +enum class inner_blk_t { + _, + _4a, _4b, + _8a, _8b, + _16a, _16b, + + _4b4a, _4b4c, _4c4b, + _8a8b, _8b8a, _8b8c, _8c8b, + _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b, + + _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c, +}; + +/** returns the offset within the block for weights blocked over oc and ic */ +template <inner_blk_t f> +constexpr int AB_or_BC_blk_off(int x0, int x1) { + using ib = inner_blk_t; + static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b, + ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b, + ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c), + "unexpected inner_blk format"); + return false ? 0 + : (f == ib::_4b4c) ? 4 * x0 + x1 + : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0 + : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1 + : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0 + : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1 + : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0 + : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1 + : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2 + : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4 + : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2 + : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4 + : INT_MIN; +} + +template <inner_blk_t b> struct inner_blk_traits { + using ib = inner_blk_t; +}; + +template <format_tag_t> struct tag_traits { + // block_dim_t block_dims; + // inner_blk_t inner_blks; + // int ndims; +}; + +#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \ +template <> struct tag_traits<format_tag::_tag> { \ + static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \ + static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \ + static constexpr int ndims = _ndims; \ +} + +DECL_TRAITS(a, _, _, 1); +DECL_TRAITS(ab, _, _, 2); +DECL_TRAITS(abc, _, _, 3); +DECL_TRAITS(abcd, _, _, 4); +DECL_TRAITS(abcde, _, _, 5); +DECL_TRAITS(abcdef, _, _, 6); +DECL_TRAITS(abdec, _, _, 5); +DECL_TRAITS(acb, _, _, 3); +DECL_TRAITS(acbde, _, _, 5); +DECL_TRAITS(acdb, _, _, 4); +DECL_TRAITS(acdeb, _, _, 5); +DECL_TRAITS(ba, _, _, 2); +DECL_TRAITS(bac, _, _, 3); +DECL_TRAITS(bacd, _, _, 4); +DECL_TRAITS(bcda, _, _, 4); +DECL_TRAITS(cba, _, _, 3); +DECL_TRAITS(cdba, _, _, 4); +DECL_TRAITS(cdeba, _, _, 5); +DECL_TRAITS(decab, _, _, 5); + +DECL_TRAITS(Abc4a, _A, _4a, 3); +DECL_TRAITS(aBc4b, _B, _4b, 3); +DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3); +DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3); +DECL_TRAITS(Abcd4a, _A, _4a, 4); +DECL_TRAITS(aBcd4b, _B, _4b, 4); +DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4); +DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4); +DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4); +DECL_TRAITS(Abcde4a, _A, _4a, 5); +DECL_TRAITS(aBcde4b, _B, _4b, 5); +DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5); +DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5); +DECL_TRAITS(aBcdef4b, _B, _4b, 6); +DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6); +DECL_TRAITS(aBdc4b, _B, _4b, 4); +DECL_TRAITS(aBdec4b, _B, _4b, 5); +DECL_TRAITS(aBdefc4b, _B, _4b, 6); +DECL_TRAITS(Acb4a, _A, _4a, 3); +DECL_TRAITS(Acdb4a, _A, _4a, 4); +DECL_TRAITS(Acdeb4a, _A, _4a, 5); + +DECL_TRAITS(Abc16a, _A, _16a, 3); +DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(aBc16b, _B, _16b, 3); +DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3); +DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3); +DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3); +DECL_TRAITS(aBc8b, _B, _8b, 3); +DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3); +DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3); +DECL_TRAITS(Abcd16a, _A, _16a, 4); +DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4); +DECL_TRAITS(aBcd16b, _B, _16b, 4); +DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4); +DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4); +DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4); +DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4); +DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4); +DECL_TRAITS(aBcd8b, _B, _8b, 4); +DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4); +DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4); +DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4); +DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4); +DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4); +DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4); +DECL_TRAITS(Abcde16a, _A, _16a, 5); +DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5); +DECL_TRAITS(aBcde16b, _B, _16b, 5); +DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5); +DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5); +DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5); +DECL_TRAITS(Abcde8a, _A, _8a, 5); +DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5); +DECL_TRAITS(aBcde8b, _B, _8b, 5); +DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5); +DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5); +DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5); +DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 5); +DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5); +DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5); +DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5); +DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5); +DECL_TRAITS(aBcdef16b, _B, _16b, 6); +DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6); +DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6); +DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6); +DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6); +DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6); +DECL_TRAITS(aBdc16b, _B, _16b, 4); +DECL_TRAITS(aBdc8b, _B, _8b, 4); +DECL_TRAITS(aBdec16b, _B, _16b, 5); +DECL_TRAITS(aBdec8b, _B, _8b, 5); +DECL_TRAITS(aBdefc16b, _B, _16b, 6); +DECL_TRAITS(aBdefc8b, _B, _8b, 6); +DECL_TRAITS(Acb16a, _A, _16a, 3); +DECL_TRAITS(Acb8a, _A, _8a, 3); +DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(Acdb16a, _A, _16a, 4); +DECL_TRAITS(Acdb8a, _A, _8a, 4); +DECL_TRAITS(Acdeb16a, _A, _16a, 5); +DECL_TRAITS(Acdeb8a, _A, _8a, 5); +DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4); + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp new file mode 100644 index 0000000000..4f06368738 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TYPE_HELPERS_HPP +#define TYPE_HELPERS_HPP + +#include <assert.h> +#include <math.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "mkldnn_traits.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "math_utils.hpp" + +namespace mkldnn { +namespace impl { + +template <typename T> +status_t safe_ptr_assign(T * &lhs, T* rhs) { + if (rhs == nullptr) return status::out_of_memory; + lhs = rhs; + return status::success; +} + +template <typename T, typename U> struct is_subset +{ static constexpr bool value = false; }; +template <typename T> struct is_subset<T, T> +{ static constexpr bool value = true; }; +template <typename T> struct is_subset<T, + typename utils::enable_if<nstl::is_integral<T>::value, float>::type> +{ static constexpr bool value = true; }; +#define ISSPEC(t1, t2) template <> \ + struct is_subset<t1, t2> { static constexpr bool value = true; } +ISSPEC(int16_t, int32_t); +ISSPEC(int8_t, int32_t); +ISSPEC(uint8_t, int32_t); +ISSPEC(int8_t, int16_t); +ISSPEC(uint8_t, int16_t); +#undef ISSPEC + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs); + +namespace types { + +inline size_t data_type_size(data_type_t data_type) { + using namespace data_type; + switch (data_type) { + case f32: return sizeof(prec_traits<f32>::type); + case s32: return sizeof(prec_traits<s32>::type); + case s8: return sizeof(prec_traits<s8>::type); + case u8: return sizeof(prec_traits<u8>::type); + case data_type::undef: + default: assert(!"unknown data_type"); + } + return 0; /* not supposed to be reachable */ +} + +inline format_kind_t format_tag_to_kind(format_tag_t tag) { + switch (tag) { + case format_tag::undef: return format_kind::undef; + case format_tag::any: return format_kind::any; + case format_tag::last: return format_kind::undef; + default: return format_kind::blocked; + } + + assert(!"unreachable"); + return format_kind::undef; +} + +inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs, + const memory_extra_desc_t &rhs) { + return true + && lhs.flags == rhs.flags + && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8, + lhs.compensation_mask == rhs.compensation_mask) + && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust, + lhs.scale_adjust == rhs.scale_adjust); +} + +inline bool blocking_desc_is_equal(const blocking_desc_t &lhs, + const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) { + using mkldnn::impl::utils::array_cmp; + return true + && lhs.inner_nblks == rhs.inner_nblks + && array_cmp(lhs.strides, rhs.strides, ndims) + && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks) + && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks); +} + +inline bool wino_desc_is_equal(const wino_desc_t &lhs, + const wino_desc_t &rhs) { + return lhs.wino_format == rhs.wino_format + && lhs.alpha == rhs.alpha + && lhs.ic == rhs.ic + && lhs.oc == rhs.oc + && lhs.ic_block == rhs.ic_block + && lhs.oc_block == rhs.oc_block + && lhs.ic2_block == rhs.ic2_block + && lhs.oc2_block == rhs.oc2_block + && lhs.r == rhs.r; +} + +inline bool rnn_packed_desc_is_equal( + const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) { + bool ok = true + && lhs.format == rhs.format + && lhs.n_parts == rhs.n_parts + && lhs.offset_compensation == rhs.offset_compensation + && lhs.size == rhs.size + && lhs.n == rhs.n; + if (!ok) + return false; + + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.parts[i] == rhs.parts[i]; + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i]; + return ok; +} + +inline memory_desc_t zero_md() { + auto zero = memory_desc_t(); + return zero; +} + +inline bool is_zero_md(const memory_desc_t *md) { + return md == nullptr || *md == zero_md(); +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t dst_dt) { + using namespace utils; + using namespace data_type; + + if (one_of(f32, src_dt, dst_dt)) return f32; + if (one_of(s32, src_dt, dst_dt)) return s32; + + if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32; + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) { + using namespace utils; + using namespace data_type; + using namespace prop_kind; + + /* prop_kind doesn't matter */ + if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32; + + if (one_of(prop_kind, forward_training, forward_inference)) { + if ((src_dt == u8 || src_dt == s8) + && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8)) + return s32; + } else if (prop_kind == backward_data) { + if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 && + one_of(dst_dt, s8, u8)) + return s32; + } + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +} + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) { + using namespace mkldnn::impl::utils; + bool base_equal = true + && lhs.ndims == rhs.ndims + && array_cmp(lhs.dims, rhs.dims, lhs.ndims) + && lhs.data_type == rhs.data_type + && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims) + && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims) + && lhs.offset0 == rhs.offset0 + && lhs.format_kind == rhs.format_kind; + if (!base_equal) return false; + if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false; + if (lhs.format_kind == format_kind::blocked) + return types::blocking_desc_is_equal(lhs.format_desc.blocking, + rhs.format_desc.blocking, lhs.ndims); + else if (lhs.format_kind == format_kind::wino) + return types::wino_desc_is_equal(lhs.format_desc.wino_desc, + rhs.format_desc.wino_desc); + else if (lhs.format_kind == format_kind::rnn_packed) + return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc, + rhs.format_desc.rnn_packed_desc); + return true; +} + +inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { + return !operator==(lhs, rhs); +} + +inline status_t memory_desc_init_by_strides(memory_desc_t &md, + const dims_t strides) { + return mkldnn_memory_desc_init_by_strides( + &md, md.ndims, md.dims, md.data_type, strides); +} + +inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + status_t status = mkldnn_memory_desc_init_by_tag( + &md, md.ndims, md.dims, md.data_type, tag); + if (status != status::success || strides == nullptr) + return status; + + /* TODO: add consistency check */ + + for (int d = 0; d < md.ndims; ++d) + md.format_desc.blocking.strides[d] = strides[d]; + + return status::success; +} + +/** inits memory descriptor based on logical dimensions kept in @p md, and the + * blocking structure @p blk. + * + * @note blk.strides represent the order only (from smaller to bigger) + * + * TODO: move md related functions to one single place + */ +inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md, + const blocking_desc_t &blk) { + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + dim_t block_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk]; + block_size *= blk.inner_blks[iblk]; + } + + for (int d = 0; d < md.ndims; ++d) { + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + md.padded_offsets[d] = 0; + } + md.offset0 = 0; + + md.format_kind = format_kind::blocked; + auto &mblk = md.format_desc.blocking; + mblk = blk; + + const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy + utils::array_copy(mblk.strides, blk.strides, ndims); + + int perm[MKLDNN_MAX_NDIMS]; + for (int d = 0; d < ndims; ++d) perm[d] = d; + + utils::simultaneous_sort(mblk.strides, perm, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + dim_t stride = block_size; + for (int _d = ndims - 1; _d >= 0; --_d) { + const int d = perm[_d]; + md.format_desc.blocking.strides[d] = stride; + stride *= md.padded_dims[d] / blocks[d]; + } + + md.extra = utils::zero<memory_extra_desc_t>(); + + return status::success; +} + +/** returns true if memory desc @p md corresponds to the given format tag and + * strides. + * If strides are not passed (or passed as nullptr) the dense structure is + * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns). + * Strides might contain `0` value, indicating the stride must match the one + * that mkldnn_memory_desc_init_by_tag() returns. + * Strides might contain `-1` values, that would be ignored during the + * comparison. For instance, this can be used if a stride along minibatch + * doesn't matter. */ +inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + if (md.format_kind != types::format_tag_to_kind(tag)) + return false; + + memory_desc_t md_gold; + status_t status = mkldnn_memory_desc_init_by_tag( + &md_gold, md.ndims, md.dims, md.data_type, tag); + if (status != status::success) return false; + + if (md.format_kind != format_kind::blocked) + return false; // unimplemented yet + + const auto &blk = md.format_desc.blocking; + const auto &blk_gold = md_gold.format_desc.blocking; + + using utils::array_cmp; + bool same_blocks = true + && blk.inner_nblks == blk_gold.inner_nblks + && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks); + + if (!same_blocks) + return false; + + if (strides == nullptr) + return array_cmp(blk.strides, blk_gold.strides, md.ndims); + + for (int d = 0; d < md.ndims; ++d) { + dim_t stride = strides[d]; + if (stride == -1) continue; + if (stride == 0) stride = blk_gold.strides[d]; + if (blk.strides[d] != stride) return false; + } + + return true; +} + +/** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ +template <typename... Tags> +format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md, + Tags ...tags) { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(md, tag)) + return tag; + } + return format_tag::undef; +} + +} +} + +#include "memory_desc_wrapper.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp new file mode 100644 index 0000000000..d23f4682dc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp @@ -0,0 +1,135 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <string.h> +#ifdef _WIN32 +#include <malloc.h> +#include <windows.h> +#endif +#include <limits.h> +#include <stdlib.h> +#include <stdio.h> + +#include "mkldnn.h" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +int getenv(const char *name, char *buffer, int buffer_size) { + if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0)) + return INT_MIN; + + int result = 0; + int term_zero_idx = 0; + size_t value_length = 0; + +#ifdef _WIN32 + value_length = GetEnvironmentVariable(name, buffer, buffer_size); +#else + const char *value = ::getenv(name); + value_length = value == NULL ? 0 : strlen(value); +#endif + + if (value_length > INT_MAX) + result = INT_MIN; + else { + int int_value_length = (int)value_length; + if (int_value_length >= buffer_size) { + result = -int_value_length; + } else { + term_zero_idx = int_value_length; + result = int_value_length; +#ifndef _WIN32 + strncpy(buffer, value, value_length); +#endif + } + } + + if (buffer != NULL) + buffer[term_zero_idx] = '\0'; + return result; +} + +int getenv_int(const char *name, int default_value) +{ + int value = default_value; + // # of digits in the longest 32-bit signed int + sign + terminating null + const int len = 12; + char value_str[len]; + if (getenv(name, value_str, len) > 0) + value = atoi(value_str); + return value; +} + +FILE *fopen(const char *filename, const char *mode) { +#ifdef _WIN32 + FILE *fp = NULL; + return ::fopen_s(&fp, filename, mode) ? NULL : fp; +#else + return ::fopen(filename, mode); +#endif +} + +void *malloc(size_t size, int alignment) { + void *ptr; + +#ifdef _WIN32 + ptr = _aligned_malloc(size, alignment); + int rc = ptr ? 0 : -1; +#else + int rc = ::posix_memalign(&ptr, alignment, size); +#endif + + return (rc == 0) ? ptr : 0; +} + +void free(void *p) { +#ifdef _WIN32 + _aligned_free(p); +#else + ::free(p); +#endif +} + +// Atomic operations +int32_t fetch_and_add(int32_t *dst, int32_t val) { +#ifdef _WIN32 + return InterlockedExchangeAdd(reinterpret_cast<long*>(dst), val); +#else + return __sync_fetch_and_add(dst, val); +#endif +} + +static int jit_dump_flag = 0; +static bool jit_dump_flag_initialized = false; +bool jit_dump_enabled() { + if (!jit_dump_flag_initialized) { + jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP"); + jit_dump_flag_initialized = true; + } + return jit_dump_flag != 0; +} + +} +} + +mkldnn_status_t mkldnn_set_jit_dump(int enabled) { + using namespace mkldnn::impl::status; + mkldnn::impl::jit_dump_flag = enabled; + mkldnn::impl::jit_dump_flag_initialized = true; + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp new file mode 100644 index 0000000000..d5a8ec5139 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp @@ -0,0 +1,370 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef UTILS_HPP +#define UTILS_HPP + +#include <stddef.h> +#include <stdio.h> +#include <stdlib.h> +#include <assert.h> +#include <stdint.h> + +#if defined(__x86_64__) || defined(_M_X64) +#define MKLDNN_X86_64 +#endif + +#define MSAN_ENABLED 0 +#if defined(__has_feature) +#if __has_feature(memory_sanitizer) +#undef MSAN_ENABLED +#define MSAN_ENABLED 1 +#include <sanitizer/msan_interface.h> +#endif +#endif + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +// Sanity check for 64 bits +static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only"); + +#define CHECK(f) do { \ + status_t status = f; \ + if (status != status::success) \ + return status; \ +} while (0) + +#define IMPLICATION(cause, effect) (!(cause) || !!(effect)) + +namespace utils { + +/* a bunch of std:: analogues to be compliant with any msvs version + * + * Rationale: msvs c++ (and even some c) headers contain special pragma that + * injects msvs-version check into object files in order to abi-mismatches + * during the static linking. This makes sense if e.g. std:: objects are passed + * through between application and library, which is not the case for mkl-dnn + * (since there is no any c++-rt dependent stuff, ideally...). */ + +/* SFINAE helper -- analogue to std::enable_if */ +template<bool expr, class T = void> struct enable_if {}; +template<class T> struct enable_if<true, T> { typedef T type; }; + +/* analogue std::conditional */ +template <bool, typename, typename> struct conditional {}; +template <typename T, typename F> struct conditional<true, T, F> +{ typedef T type; }; +template <typename T, typename F> struct conditional<false, T, F> +{ typedef F type; }; + +template <bool, typename, bool, typename, typename> struct conditional3 {}; +template <typename T, typename FT, typename FF> +struct conditional3<true, T, false, FT, FF> { typedef T type; }; +template <typename T, typename FT, typename FF> +struct conditional3<false, T, true, FT, FF> { typedef FT type; }; +template <typename T, typename FT, typename FF> +struct conditional3<false, T, false, FT, FF> { typedef FF type; }; + +template <bool, typename U, U, U> struct conditional_v {}; +template <typename U, U t, U f> struct conditional_v<true, U, t, f> +{ static constexpr U value = t; }; +template <typename U, U t, U f> struct conditional_v<false, U, t, f> +{ static constexpr U value = f; }; + +template <typename T> struct remove_reference { typedef T type; }; +template <typename T> struct remove_reference<T&> { typedef T type; }; +template <typename T> struct remove_reference<T&&> { typedef T type; }; + +template <typename T> +inline T&& forward(typename utils::remove_reference<T>::type &t) +{ return static_cast<T&&>(t); } +template <typename T> +inline T&& forward(typename utils::remove_reference<T>::type &&t) +{ return static_cast<T&&>(t); } + +template <typename T> +inline typename remove_reference<T>::type zero() +{ auto zero = typename remove_reference<T>::type(); return zero; } + +template <typename T, typename P> +inline bool everyone_is(T val, P item) { return val == item; } +template <typename T, typename P, typename... Args> +inline bool everyone_is(T val, P item, Args... item_others) { + return val == item && everyone_is(val, item_others...); +} + +template <typename T, typename P> +constexpr bool one_of(T val, P item) { return val == item; } +template <typename T, typename P, typename... Args> +constexpr bool one_of(T val, P item, Args... item_others) { + return val == item || one_of(val, item_others...); +} + +template <typename... Args> +inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); } + +template<typename T> +inline void array_copy(T *dst, const T *src, size_t size) { + for (size_t i = 0; i < size; ++i) dst[i] = src[i]; +} +template<typename T> +inline bool array_cmp(const T *a1, const T *a2, size_t size) { + for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false; + return true; +} +template<typename T, typename U> +inline void array_set(T *arr, const U& val, size_t size) { + for (size_t i = 0; i < size; ++i) arr[i] = static_cast<T>(val); +} + +namespace product_impl { +template<size_t> struct int2type{}; + +template <typename T> +constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; } + +template <typename T, size_t num> +inline T product_impl(const T *arr, int2type<num>) { + return arr[0]*product_impl(arr+1, int2type<num-1>()); } +} + +template <size_t num, typename T> +inline T array_product(const T *arr) { + return product_impl::product_impl(arr, product_impl::int2type<num-1>()); +} + +template<typename T, typename R = T> +inline R array_product(const T *arr, size_t size) { + R prod = 1; + for (size_t i = 0; i < size; ++i) prod *= arr[i]; + return prod; +} + +/** sorts an array of values using @p comparator. While sorting the array + * of value, the function permutes an array of @p keys accordingly. + * + * @note The arrays of @p keys can be omitted. In this case the function + * sorts the array of @vals only. + */ +template <typename T, typename U, typename F> +inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) { + if (size == 0) return; + + for (size_t i = 0; i < size - 1; ++i) { + bool swapped = false; + + for (size_t j = 0; j < size - i - 1; j++) { + if (comparator(vals[j], vals[j + 1]) > 0) { + nstl::swap(vals[j], vals[j + 1]); + if (keys) nstl::swap(keys[j], keys[j + 1]); + swapped = true; + } + } + + if (swapped == false) break; + } +} + +template <typename T, typename U> +inline typename remove_reference<T>::type div_up(const T a, const U b) { + assert(b); + return (a + b - 1) / b; +} + +template <typename T, typename U> +inline typename remove_reference<T>::type rnd_up(const T a, const U b) { + return div_up(a, b) * b; +} + +template <typename T, typename U> +inline typename remove_reference<T>::type rnd_dn(const T a, const U b) { + return (a / b) * b; +} + +template <typename T> T *align_ptr(T *ptr, uintptr_t alignment) +{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); } + +template <typename T, typename U, typename V> +inline U this_block_size(const T offset, const U max, const V block_size) { + assert(offset < max); + // TODO (Roma): can't use nstl::max() due to circular dependency... we + // need to fix this + const T block_boundary = offset + block_size; + if (block_boundary > max) + return max - offset; + else + return block_size; +} + +template<typename T> +inline T nd_iterator_init(T start) { return start; } +template<typename T, typename U, typename W, typename... Args> +inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) { + start = nd_iterator_init(start, utils::forward<Args>(tuple)...); + x = start % X; + return start / X; +} + +inline bool nd_iterator_step() { return true; } +template<typename U, typename W, typename... Args> +inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) { + if (nd_iterator_step(utils::forward<Args>(tuple)...) ) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template<typename U, typename W, typename Y> +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X) +{ + U max_jump = end - cur; + U dim_jump = X - x; + if (dim_jump <= max_jump) { + x = 0; + cur += dim_jump; + return true; + } else { + cur += max_jump; + x += max_jump; + return false; + } +} +template<typename U, typename W, typename Y, typename... Args> +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X, + Args &&... tuple) +{ + if (nd_iterator_jump(cur, end, utils::forward<Args>(tuple)...)) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template <typename T> +inline T pick(size_t i, const T &x0) { return x0; } +template <typename T, typename ...Args> +inline T pick(size_t i, const T &x0, Args &&... args) { + return i == 0 ? x0 : pick(i - 1, utils::forward<Args>(args)...); +} + +template <typename T> +T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference, + const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) { + switch (prop_kind) { + case prop_kind::forward_inference: return val_fwd_inference; + case prop_kind::forward_training: return val_fwd_training; + case prop_kind::backward_data: return val_bwd_d; + case prop_kind::backward_weights: return val_bwd_w; + default: assert(!"unsupported prop_kind"); + } + return T(); +} + +template <typename T> +T pick_by_prop_kind(prop_kind_t prop_kind, + const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w) +{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); } + +template <typename Telem, size_t Tdims> +struct array_offset_calculator { + template <typename... Targs> + array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... } + { + _base_ptr = base; + } + template <typename... Targs> + inline Telem &operator()(Targs... Fargs) + { + return *(_base_ptr + _offset(1, Fargs...)); + } + +private: + template <typename... Targs> + inline size_t _offset(size_t const dimension, size_t element) + { + return element; + } + + template <typename... Targs> + inline size_t _offset(size_t const dimension, size_t theta, size_t element) + { + return element + (_dims[dimension] * theta); + } + + template <typename... Targs> + inline size_t _offset(size_t const dimension, size_t theta, size_t element, + Targs... Fargs) + { + size_t t_prime = element + (_dims[dimension] * theta); + return _offset(dimension + 1, t_prime, Fargs...); + } + + Telem *_base_ptr; + const int _dims[Tdims]; +}; + +} + +int32_t fetch_and_add(int32_t *dst, int32_t val); +inline void yield_thread() {} + +// Reads an environment variable 'name' and stores its string value in the +// 'buffer' of 'buffer_size' bytes on success. +// +// - Returns the length of the environment variable string value (excluding +// the terminating 0) if it is set and its contents (including the terminating +// 0) can be stored in the 'buffer' without truncation. +// +// - Returns negated length of environment variable string value and writes +// "\0" to the buffer (if it is not NULL) if the 'buffer_size' is to small to +// store the value (including the terminating 0) without truncation. +// +// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment +// variable is not set. +// +// - Returns INT_MIN if the 'name' is NULL. +// +// - Returns INT_MIN if the 'buffer_size' is negative. +// +// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than +// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to +// retrieve the length of the environment variable value string. +// +int getenv(const char *name, char *buffer, int buffer_size); +// Reads an integer from the environment +int getenv_int(const char *name, int default_value = 0); +bool jit_dump_enabled(); +FILE *fopen(const char *filename, const char *mode); + +constexpr int msan_enabled = MSAN_ENABLED; +inline void msan_unpoison(void *ptr, size_t size) { +#if MSAN_ENABLED + __msan_unpoison(ptr, size); +#endif +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp new file mode 100644 index 0000000000..89a57772cf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp @@ -0,0 +1,665 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <stdlib.h> +#ifndef _WIN32 +#include <sys/time.h> +#endif + +#include "mkldnn.h" +#include "mkldnn_version.h" +#include "c_types_map.hpp" +#include "verbose.hpp" +#include "cpu/cpu_isa_traits.hpp" + +#include "batch_normalization_pd.hpp" +#include "pooling_pd.hpp" +#include "concat_pd.hpp" +#include "reorder_pd.hpp" +#include "convolution_pd.hpp" +#include "rnn_pd.hpp" +#include "deconvolution_pd.hpp" +#include "shuffle_pd.hpp" +#include "eltwise_pd.hpp" +#include "softmax_pd.hpp" +#include "inner_product_pd.hpp" +#include "sum_pd.hpp" +#include "lrn_pd.hpp" + +/* MKL-DNN CPU ISA info */ +#define ISA_ANY "No instruction set specific optimizations" +#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)" +#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)" +#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)" +#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512)" +#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions" +#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \ + "AVX512-DL Boost)" +#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions" +#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions" + +namespace mkldnn { +namespace impl { + +static verbose_t verbose; +static bool initialized; +static bool version_printed = false; + +const verbose_t *mkldnn_verbose() { +#if !defined(DISABLE_VERBOSE) + if (!initialized) { + const int len = 2; + char val[len] = {0}; + if (getenv("MKLDNN_VERBOSE", val, len) == 1) + verbose.level = atoi(val); + initialized = true; + } + if (!version_printed && verbose.level > 0) { + printf("mkldnn_verbose,info," + "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n", + mkldnn_version()->major, mkldnn_version()->minor, + mkldnn_version()->patch, mkldnn_version()->hash, + get_isa_info()); + version_printed = true; + } +#else + verbose.level = 0; +#endif + return &verbose; +} + +double get_msec() { +#ifdef _WIN32 + static LARGE_INTEGER frequency; + if (frequency.QuadPart == 0) + QueryPerformanceFrequency(&frequency); + LARGE_INTEGER now; + QueryPerformanceCounter(&now); + return 1e+3 * now.QuadPart / frequency.QuadPart; +#else + struct timeval time; + gettimeofday(&time, NULL); + return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec; +#endif +} + +const char *get_isa_info() { + using namespace mkldnn::impl::cpu; + if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS; + if (mayiuse(avx512_mic)) return AVX512_MIC; + if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI; + if (mayiuse(avx512_core)) return AVX512_CORE; + if (mayiuse(avx512_common)) return AVX512_COMMON; + if (mayiuse(avx2)) return AVX2; + if (mayiuse(avx)) return AVX; + if (mayiuse(sse42)) return SSE42; + return ISA_ANY; +} + +/* init_info section */ +namespace { +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_DAT_LEN 256 +#define MKLDNN_VERBOSE_AUX_LEN 384 +#define MKLDNN_VERBOSE_PRB_LEN 384 + +#define DECL_DAT_AUX_PRB_STRS() \ + int dat_written = 0, aux_written = 0, prb_written = 0; \ + MAYBE_UNUSED((dat_written * aux_written * prb_written)); \ + char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \ + char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; MAYBE_UNUSED(aux_str); \ + char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str) + +#define DFMT "%" PRId64 + +void clear_buf(char *buf, int &written) { + /* TODO: do it better */ + buf[0] = '#'; + buf[1] = '\0'; + written = 1; +} + +#define DPRINT(buf, buf_len, written, ...) do { \ + int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \ + if (l < 0 || written + l > buf_len) { \ + clear_buf(buf, written); \ + } else { \ + written += l; \ + } \ +} while(0) + +// XXX: Outputs strings corresponding to memory formats used for data tensors. +void format_prb_desc_str(char *str, int len, const memory_desc_t *md) { + const auto dims = md->dims; + int written = 0; + if (md->ndims == 1) + DPRINT(str, len, written, + "x" DFMT, dims[0]); + else if (md->ndims == 2) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT, dims[0], dims[1]); + else if (md->ndims == 3) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "iw" DFMT, + dims[0], dims[1], dims[2]); + else if (md->ndims == 4) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3]); + else if (md->ndims == 5) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3], dims[4]); + else + mkldnn_md2dim_str(str, len, md); +} + +void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind, + const char *impl_str, mkldnn_prop_kind_t prop_kind, + const char *data_str, const char *aux_str, const char *prb_str) { + MAYBE_UNUSED(verbose_templ); + int written = 0; + DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s", + mkldnn_prim_kind2str(prim_kind), impl_str, + mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str); +} + +template <typename pd_t> static void init_info_bnorm(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "flags:%u", s->desc()->flags); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_conv(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->ndims() == 5) { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } else { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_shuffle(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + auto md = s->is_fwd() ? s->src_md() : s->diff_dst_md(); + + if (1) { // data + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "axis:%d group_size:" DFMT, s->axis(), s->group_size()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_eltwise(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_iprod(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_lrn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_mem(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "num:%d", s->n_inputs()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_pool(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->is_fwd() ? s->src_md() : s->diff_src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->is_fwd() ? s->dst_md() : s->diff_dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // ws + auto md = s->workspace_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->is_3d()) { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "", + s->MB(), s->C(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } else { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT, + s->MB(), s->C(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_softmax(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_rnn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src layer + auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // src iter + auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_layer + auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_iter + auto md = s->is_fwd() ? s->weights_md(1) : s->diff_weights_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bias + auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst layer + auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst iter + auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + alg_kind_t alg_kind = s->cell_kind(); + rnn_direction_t rnn_dir = s->direction(); + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s_%s", mkldnn_alg_kind2str(alg_kind), + mkldnn_rnn_direction2str(rnn_dir)); + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "l" DFMT "t" DFMT "mb" DFMT + "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT, + s->L(), s->T(), s->MB(), + s->SIC(), s->SLC(), s->DIC(), s->DLC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +#undef DPRINT + +#else // !defined(DISABLE_VERBOSE) + +#define DEFINE_STUB(name) \ + template <typename pd_t> \ + static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \ + { UNUSED(s); UNUSED(buffer); } + +DEFINE_STUB(bnorm); +DEFINE_STUB(conv); +DEFINE_STUB(eltwise); +DEFINE_STUB(iprod); +DEFINE_STUB(lrn); +DEFINE_STUB(mem); +DEFINE_STUB(pool); +DEFINE_STUB(softmax); +DEFINE_STUB(rnn); +DEFINE_STUB(shuffle); +#undef DEFINE_STUB + +#endif // !defined(DISABLE_VERBOSE) +} + +void init_info(batch_normalization_pd_t *s, char *b) +{ init_info_bnorm(s, b); } +void init_info(concat_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(convolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(deconvolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(eltwise_pd_t *s, char *b) +{ init_info_eltwise(s, b); } +void init_info(inner_product_pd_t *s, char *b) +{ init_info_iprod(s, b); } +void init_info(lrn_pd_t *s, char *b) +{ init_info_lrn(s, b); } +void init_info(pooling_pd_t *s, char *b) +{ init_info_pool(s, b); } +void init_info(reorder_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(rnn_pd_t *s, char *b) +{ init_info_rnn(s, b); } +void init_info(shuffle_pd_t *s, char *b) +{ init_info_shuffle(s, b); } +void init_info(softmax_pd_t *s, char *b) +{ init_info_softmax(s, b); } +void init_info(sum_pd_t *s, char *b) +{ init_info_mem(s, b); } + +} +} + +mkldnn_status_t mkldnn_set_verbose(int level) { + using namespace mkldnn::impl::status; + if (level < 0 || level > 2) return invalid_arguments; + mkldnn::impl::verbose.level = level; + mkldnn::impl::initialized = true; + return success; +} + +const mkldnn_version_t *mkldnn_version() { + static mkldnn_version_t ver = { + MKLDNN_VERSION_MAJOR, + MKLDNN_VERSION_MINOR, + MKLDNN_VERSION_PATCH, + MKLDNN_VERSION_HASH}; + return &ver; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp new file mode 100644 index 0000000000..e3049750cb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef VERBOSE_HPP +#define VERBOSE_HPP + +#include <stdio.h> +#include <cinttypes> + +#include "mkldnn_debug.h" +#include "c_types_map.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +struct verbose_t { + int level; +}; + +const verbose_t *mkldnn_verbose(); +double get_msec(); +const char *get_isa_info(); + +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_BUF_LEN 1024 +#else +#define MKLDNN_VERBOSE_BUF_LEN 1 +#endif + +void init_info(batch_normalization_pd_t *s, char *buffer); +void init_info(concat_pd_t *s, char *buffer); +void init_info(convolution_pd_t *s, char *buffer); +void init_info(deconvolution_pd_t *s, char *buffer); +void init_info(eltwise_pd_t *s, char *buffer); +void init_info(inner_product_pd_t *s, char *buffer); +void init_info(lrn_pd_t *s, char *buffer); +void init_info(pooling_pd_t *s, char *buffer); +void init_info(reorder_pd_t *s, char *buffer); +void init_info(rnn_pd_t *s, char *buffer); +void init_info(shuffle_pd_t *s, char *buffer); +void init_info(softmax_pd_t *s, char *buffer); +void init_info(sum_pd_t *s, char *buffer); + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp new file mode 100644 index 0000000000..520bd4710b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef Z_MAGIC_HPP +#define Z_MAGIC_HPP + +#define CHAIn2(a,b) a b +#define CHAIN2(a,b) CHAIn2(a,b) + +#define CONCAt2(a,b) a ## b +#define CONCAT2(a,b) CONCAt2(a,b) + +#define STRINGIFy(s) #s +#define STRINGIFY(s) STRINGIFy(s) + +#ifdef _MSC_VER +# define PRAGMA_MACRo(x) __pragma(x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#else +# define PRAGMA_MACRo(x) _Pragma(#x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#endif + +#define UNUSED(x) ((void)x) +#define MAYBE_UNUSED(x) UNUSED(x) + +#if defined(_WIN32) && !defined(__GNUC__) +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |