summaryrefslogtreecommitdiff
path: root/thirdparty/oidn/mkl-dnn/src/common
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/common')
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp104
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp240
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp550
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/concat.cpp86
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp211
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/convolution.cpp200
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp56
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp348
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp188
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp293
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp84
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp161
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/engine.cpp75
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/engine.hpp119
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp106
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp56
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp321
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/lrn.cpp91
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp170
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp280
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/memory.cpp238
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/memory.hpp63
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp212
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp400
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp295
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp131
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp365
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp115
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp277
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp77
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/nstl.hpp193
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/pooling.cpp114
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp238
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive.cpp103
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive.hpp76
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp290
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp183
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp78
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp174
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp90
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp68
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp89
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp79
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/query.cpp59
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/reorder.cpp68
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp85
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/rnn.cpp400
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp280
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp112
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp36
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp72
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp121
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/softmax.cpp68
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp161
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/stream.cpp46
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/stream.hpp44
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/sum.cpp79
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp143
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp200
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp348
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/utils.cpp135
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/utils.hpp370
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/verbose.cpp665
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/verbose.hpp62
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp46
65 files changed, 11287 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp
new file mode 100644
index 0000000000..1a51d8562b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp
@@ -0,0 +1,104 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
+ prop_kind_t prop_kind, const memory_desc_t *data_desc,
+ const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) {
+ bool args_ok = true
+ && !any_null(bnrm_desc, data_desc)
+ && one_of(prop_kind, forward_training, forward_inference,
+ backward_data, backward)
+ && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr);
+ if (!args_ok) return invalid_arguments;
+
+ auto bd = batch_normalization_desc_t();
+ bd.primitive_kind = primitive_kind::batch_normalization;
+ bd.prop_kind = prop_kind;
+
+ bd.data_desc = *data_desc;
+ bd.diff_data_desc = zero_md();
+ if ( one_of(bd.prop_kind,backward_data, backward) )
+ bd.diff_data_desc = *diff_data_desc;
+
+ dims_t scaleshift_dims = { 2, data_desc->dims[1] };
+ mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2,
+ scaleshift_dims, data_type::f32, mkldnn_nc);
+ bd.diff_data_scaleshift_desc = zero_md();
+ if (bd.prop_kind == backward) {
+ bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc;
+ }
+
+ dims_t stats_dims = { data_desc->dims[1] };
+ mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims,
+ data_type::f32, mkldnn_x);
+ bd.variance_desc = bd.mean_desc;
+ bd.batch_norm_epsilon = epsilon;
+
+ unsigned bnorm_flags =
+ mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
+ if ((~bnorm_flags & flags) != 0) return invalid_arguments;
+
+ bd.flags = flags;
+
+ bool consistency = true
+ && utils::one_of(bd.data_desc.ndims, 2, 4, 5);
+ if (bd.prop_kind == backward_data)
+ consistency = consistency
+ && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5)
+ && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims,
+ bd.diff_data_desc.ndims);
+ if (!consistency) return invalid_arguments;
+
+ *bnrm_desc = bd;
+ return success;
+}
+}
+
+status_t mkldnn_batch_normalization_forward_desc_init(
+ batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
+ const memory_desc_t *data_desc, float epsilon, unsigned flags) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr,
+ epsilon, flags);
+}
+
+status_t mkldnn_batch_normalization_backward_desc_init(
+ batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
+ const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
+ float epsilon, unsigned flags) {
+ if (!one_of(prop_kind, backward, backward_data))
+ return invalid_arguments;
+ return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc,
+ epsilon, flags);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp
new file mode 100644
index 0000000000..f61410b33c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp
@@ -0,0 +1,240 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef BATCH_NORMALIZATION_PD_HPP
+#define BATCH_NORMALIZATION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct batch_normalization_fwd_pd_t;
+
// Base primitive descriptor shared by the forward and backward batch
// normalization descriptors below.  Holds the operation descriptor and
// the memory descriptors derived from it (data, statistics, scale-shift,
// and an optional workspace).
struct batch_normalization_pd_t: public primitive_desc_t {
    static constexpr auto base_pkind = primitive_kind::batch_normalization;

    batch_normalization_pd_t(engine_t *engine,
            const batch_normalization_desc_t *adesc,
            const primitive_attr_t *attr,
            const batch_normalization_fwd_pd_t *hint_fwd_pd)
        : primitive_desc_t(engine, attr, base_pkind)
        , desc_(*adesc)
        , hint_fwd_pd_(hint_fwd_pd)
        , data_md_(desc_.data_desc)
        , stat_md_(desc_.mean_desc)
        , scaleshift_md_(desc_.data_scaleshift_desc)
        , ws_md_()
    {}

    // Underlying C operation descriptor.
    const batch_normalization_desc_t *desc() const { return &desc_; }
    virtual const op_desc_t *op_desc() const override
    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
    virtual void init_info() override { impl::init_info(this, this->info_); }

    // Answers query::batch_normalization_d with a pointer to desc_;
    // every other query is forwarded to the base class.
    virtual status_t query(query_t what, int idx, void *result) const override {
        switch (what) {
        case query::batch_normalization_d:
            *(const batch_normalization_desc_t**)result = desc(); break;
        default: return primitive_desc_t::query(what, idx, result);
        }
        return status::success;
    }

    /* common batch_normalization aux functions */

    dim_t MB() const { return data_desc().dims[0]; }  // mini-batch size
    dim_t C() const { return data_desc().dims[1]; }   // channel count
    // Spatial sizes; dimensions absent from lower-rank tensors read as 1.
    dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
    dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
    dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }

    int ndims() const { return desc_.data_desc.ndims; }

    // Flag accessors.  Note stats_is_src() and use_global_stats() test
    // the same mkldnn_use_global_stats bit.
    bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; }
    bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; }
    bool use_global_stats() const
    { return desc_.flags & mkldnn_use_global_stats; }
    bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; }
    // True when the attached post-ops are exactly one eltwise ReLU.
    bool with_relu_post_op() const {
        const auto &p = this->attr()->post_ops_;
        return p.len_ == 1 && p.entry_[0].is_relu(true, true);
    }

    bool is_fwd() const {
        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
                prop_kind::forward_inference);
    }
    bool is_bwd() const { return !this->is_fwd(); }
    bool is_training() const
    { return desc_.prop_kind == prop_kind::forward_training; }

    bool has_zero_dim_memory() const
    { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }

protected:
    batch_normalization_desc_t desc_;
    const batch_normalization_fwd_pd_t *hint_fwd_pd_;

    memory_desc_t data_md_;
    memory_desc_t stat_md_;
    memory_desc_t scaleshift_md_;

    memory_desc_t ws_md_;  // workspace; empty until init_default_ws()

    // Sets up ws_md_ as a flat 1D u8 buffer large enough to hold
    // bits_per_element bits for every element of data_md_ (including
    // padded elements, see nelems(true)).
    void init_default_ws(size_t bits_per_element) {
        const auto data_mdw = memory_desc_wrapper(data_md_);

        const dim_t data_nelems = data_mdw.nelems(true);
        const dim_t bits_per_byte = 8;
        const dims_t ws_sz = { (dim_t)utils::div_up(
                data_nelems * bits_per_element, bits_per_byte) };
        mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8,
                format_tag::x);
    }

private:
    const memory_desc_t &data_desc() const { return desc_.data_desc; }
};
+
// Forward batch normalization primitive descriptor: argument usage and
// memory descriptor queries for forward_training / forward_inference.
struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t {
    typedef batch_normalization_fwd_pd_t base_class;
    typedef batch_normalization_fwd_pd_t hint_class;

    batch_normalization_fwd_pd_t(engine_t *engine,
            const batch_normalization_desc_t *adesc,
            const primitive_attr_t *attr,
            const batch_normalization_fwd_pd_t *hint_fwd_pd)
        : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
    {}

    // Classifies an execution argument:
    //  - mean/variance are inputs when supplied by the user
    //    (stats_is_src()), outputs when computed during training,
    //    unused otherwise;
    //  - scale-shift is an input only when use_scaleshift() is set;
    //  - workspace is produced only for training with fused ReLU.
    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
        if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input;
        if (arg == MKLDNN_ARG_DST) return arg_usage_t::output;

        if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) {
            if (stats_is_src()) return arg_usage_t::input;
            if (!stats_is_src() && is_training()) return arg_usage_t::output;
            return arg_usage_t::unused;
        }

        if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
            return arg_usage_t::input;

        if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu())
            return arg_usage_t::output;

        return primitive_desc_t::arg_usage(arg);
    }

    // Index 0 is the data tensor; indices 1/2 alias the statistics
    // descriptor on whichever side (src vs dst) the stats live.
    virtual const memory_desc_t *src_md(int index = 0) const override {
        if (index == 0) return &data_md_;
        if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_;
        return nullptr;
    }

    virtual const memory_desc_t *dst_md(int index = 0) const override {
        if (index == 0) return &data_md_;
        if (!stats_is_src() && is_training() && (index == 1 || index == 2))
            return &stat_md_;
        return nullptr;
    }

    virtual const memory_desc_t *weights_md(int index = 0) const override
    { return index == 0 ? &scaleshift_md_ : nullptr; }

    virtual const memory_desc_t *workspace_md(int index = 0) const override
    { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; }

    // Statistics descriptor regardless of whether the stats are an
    // input or an output of this primitive.
    const memory_desc_t *stat_md() const
    { return stats_is_src() ? src_md(1) : dst_md(1); }

    // src (+ mean & variance when user-supplied) (+ scale-shift)
    virtual int n_inputs() const override
    { return 1 + 2 * stats_is_src() + use_scaleshift(); }
    // dst (+ workspace and/or computed mean & variance while training)
    virtual int n_outputs() const override
    { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); }
};
+
// Backward batch normalization primitive descriptor: adds the diff data
// and diff scale-shift memory descriptors on top of the common base.
struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t {
    typedef batch_normalization_bwd_pd_t base_class;
    typedef batch_normalization_fwd_pd_t hint_class;

    batch_normalization_bwd_pd_t(engine_t *engine,
            const batch_normalization_desc_t *adesc,
            const primitive_attr_t *attr,
            const batch_normalization_fwd_pd_t *hint_fwd_pd)
        : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
        , diff_data_md_(desc_.diff_data_desc)
        , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc)
    {}

    // Backward pass always consumes src, mean, variance and diff_dst;
    // scale-shift (in) and diff scale-shift (out) depend on
    // use_scaleshift(); workspace is consumed when ReLU was fused.
    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN,
                    MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST))
            return arg_usage_t::input;

        if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
            return arg_usage_t::input;

        if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu())
            return arg_usage_t::input;

        if (arg == MKLDNN_ARG_DIFF_SRC)
            return arg_usage_t::output;

        if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift())
            return arg_usage_t::output;

        return primitive_desc_t::arg_usage(arg);
    }

    // Index 0 is the data tensor; indices 1/2 alias the statistics.
    virtual const memory_desc_t *src_md(int index = 0) const override
    { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; }
    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
    { return index == 0 ? &diff_data_md_ : nullptr; }
    virtual const memory_desc_t *diff_src_md(int index = 0) const override
    { return index == 0 ? &diff_data_md_ : nullptr; }

    virtual const memory_desc_t *weights_md(int index = 0) const override
    { return index == 0 ? &scaleshift_md_ : nullptr; }
    virtual const memory_desc_t *diff_weights_md(int index = 0) const override
    { return index == 0 ? &diff_scaleshift_md_ : nullptr; }

    virtual const memory_desc_t *workspace_md(int index = 0) const override
    { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; }

    const memory_desc_t *stat_md() const { return src_md(1); }

    // src, mean, variance, diff_dst (+ scale-shift) (+ workspace)
    virtual int n_inputs() const override
    { return 4 + use_scaleshift() + fuse_bn_relu(); }
    // diff_src (+ diff scale-shift for full backward)
    virtual int n_outputs() const override
    { return 1 + (desc_.prop_kind == prop_kind::backward); }

protected:
    memory_desc_t diff_data_md_;
    memory_desc_t diff_scaleshift_md_;
};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
new file mode 100644
index 0000000000..3d43a0fbee
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
@@ -0,0 +1,550 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef TYPE_MAPPING_HPP
+#define TYPE_MAPPING_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+
+// TODO: autogenerate this
+
// Scalar and array aliases for dimensions and strides, matching the C API.
using dim_t = mkldnn_dim_t;
using dims_t = mkldnn_dims_t;
using stride_t = mkldnn_dim_t;
using strides_t = mkldnn_strides_t;
+
using status_t = mkldnn_status_t;
// C API status codes re-exported as short C++-side constants.
namespace status {
    const status_t success = mkldnn_success;
    const status_t out_of_memory = mkldnn_out_of_memory;
    const status_t try_again = mkldnn_try_again;
    const status_t invalid_arguments = mkldnn_invalid_arguments;
    const status_t not_ready = mkldnn_not_ready;
    const status_t unimplemented = mkldnn_unimplemented;
    const status_t iterator_ends = mkldnn_iterator_ends;
    const status_t runtime_error = mkldnn_runtime_error;
    const status_t not_required = mkldnn_not_required;
}
+
using prop_kind_t = mkldnn_prop_kind_t;
// Propagation kinds (forward/backward variants) from the C API.
namespace prop_kind {
    const prop_kind_t undef = mkldnn_prop_kind_undef;
    const prop_kind_t forward_training = mkldnn_forward_training;
    const prop_kind_t forward_inference = mkldnn_forward_inference;
    const prop_kind_t forward_scoring = mkldnn_forward_scoring;
    const prop_kind_t forward = mkldnn_forward;
    const prop_kind_t backward = mkldnn_backward;
    const prop_kind_t backward_data = mkldnn_backward_data;
    const prop_kind_t backward_weights = mkldnn_backward_weights;
    const prop_kind_t backward_bias = mkldnn_backward_bias;
}
+
using alg_kind_t = mkldnn_alg_kind_t;
// Algorithm kinds (per-primitive flavors) from the C API.
namespace alg_kind {
    const alg_kind_t undef = mkldnn_alg_kind_undef;
    const alg_kind_t convolution_auto = mkldnn_convolution_auto;
    const alg_kind_t convolution_direct = mkldnn_convolution_direct;
    const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
    const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
    const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd;
    const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
    const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
    const alg_kind_t eltwise_elu = mkldnn_eltwise_elu;
    const alg_kind_t eltwise_square = mkldnn_eltwise_square;
    const alg_kind_t eltwise_abs = mkldnn_eltwise_abs;
    const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt;
    const alg_kind_t eltwise_linear = mkldnn_eltwise_linear;
    const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu;
    const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu;
    const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic;
    const alg_kind_t pooling_max = mkldnn_pooling_max;
    const alg_kind_t pooling_avg = mkldnn_pooling_avg;
    const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
    const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding;
    const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels;
    const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel;
    const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn;
    const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm;
    const alg_kind_t vanilla_gru = mkldnn_vanilla_gru;
    const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset;
}
+
using data_type_t = mkldnn_data_type_t;
// Supported element data types from the C API.
namespace data_type {
    const data_type_t undef = mkldnn_data_type_undef;
    const data_type_t f32 = mkldnn_f32;
    const data_type_t s32 = mkldnn_s32;
    const data_type_t s8 = mkldnn_s8;
    const data_type_t u8 = mkldnn_u8;
}
+
using scratchpad_mode_t = mkldnn_scratchpad_mode_t;
// Who owns the scratchpad memory: the library or the user.
namespace scratchpad_mode {
    const scratchpad_mode_t library = mkldnn_scratchpad_mode_library;
    const scratchpad_mode_t user = mkldnn_scratchpad_mode_user;
}
+
using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t;
// Packed memory formats used by RNN weights.
namespace rnn_packed_format {
    const rnn_packed_format_t undef = mkldnn_packed_format_undef;
    const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p;
    const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p;
}
+
using format_kind_t = mkldnn_format_kind_t;
// Broad memory format categories from the C API.
namespace format_kind {
    const format_kind_t undef = mkldnn_format_kind_undef;
    const format_kind_t any = mkldnn_format_kind_any;
    const format_kind_t blocked = mkldnn_blocked;
    const format_kind_t wino = mkldnn_format_kind_wino;
    const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed;
}
+
using format_tag_t = mkldnn_format_tag_t;
// Memory format tags from the C API, re-exported without the mkldnn_
// prefix.  First the dimension-letter ("abc...") tags, then the
// semantic aliases (nchw, oihw, ...) that map onto them.
namespace format_tag {
    const format_tag_t undef = mkldnn_format_tag_undef;
    const format_tag_t any = mkldnn_format_tag_any;
    const format_tag_t a = mkldnn_a;
    const format_tag_t ab = mkldnn_ab;
    const format_tag_t abc = mkldnn_abc;
    const format_tag_t abcd = mkldnn_abcd;
    const format_tag_t abcde = mkldnn_abcde;
    const format_tag_t abcdef = mkldnn_abcdef;
    const format_tag_t abdec = mkldnn_abdec;
    const format_tag_t acb = mkldnn_acb;
    const format_tag_t acbde = mkldnn_acbde;
    const format_tag_t acdb = mkldnn_acdb;
    const format_tag_t acdeb = mkldnn_acdeb;
    const format_tag_t ba = mkldnn_ba;
    const format_tag_t bac = mkldnn_bac;
    const format_tag_t bacd = mkldnn_bacd;
    const format_tag_t bcda = mkldnn_bcda;
    const format_tag_t cba = mkldnn_cba;
    const format_tag_t cdba = mkldnn_cdba;
    const format_tag_t cdeba = mkldnn_cdeba;
    const format_tag_t decab = mkldnn_decab;
    const format_tag_t Abc16a = mkldnn_Abc16a;
    const format_tag_t ABc16a16b = mkldnn_ABc16a16b;
    const format_tag_t aBc16b = mkldnn_aBc16b;
    const format_tag_t ABc16b16a = mkldnn_ABc16b16a;
    const format_tag_t Abc4a = mkldnn_Abc4a;
    const format_tag_t aBc4b = mkldnn_aBc4b;
    const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b;
    const format_tag_t ABc4b4a = mkldnn_ABc4b4a;
    const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a;
    const format_tag_t ABc8a8b = mkldnn_ABc8a8b;
    const format_tag_t aBc8b = mkldnn_aBc8b;
    const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b;
    const format_tag_t ABc8b8a = mkldnn_ABc8b8a;
    const format_tag_t Abcd16a = mkldnn_Abcd16a;
    const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b;
    const format_tag_t aBcd16b = mkldnn_aBcd16b;
    const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a;
    const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c;
    const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b;
    const format_tag_t Abcd4a = mkldnn_Abcd4a;
    const format_tag_t aBcd4b = mkldnn_aBcd4b;
    const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b;
    const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a;
    const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c;
    const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b;
    const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a;
    const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b;
    const format_tag_t aBcd8b = mkldnn_aBcd8b;
    const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b;
    const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b;
    const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a;
    const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c;
    const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c;
    const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b;
    const format_tag_t Abcde16a = mkldnn_Abcde16a;
    const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b;
    const format_tag_t aBcde16b = mkldnn_aBcde16b;
    const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a;
    const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c;
    const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b;
    const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c;
    const format_tag_t Abcde4a = mkldnn_Abcde4a;
    const format_tag_t aBcde4b = mkldnn_aBcde4b;
    const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a;
    const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c;
    const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c;
    const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b;
    const format_tag_t Abcde8a = mkldnn_Abcde8a;
    const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b;
    const format_tag_t aBcde8b = mkldnn_aBcde8b;
    const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b;
    const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b;
    const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a;
    const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c;
    const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c;
    const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b;
    const format_tag_t aBcdef16b = mkldnn_aBcdef16b;
    const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c;
    const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b;
    const format_tag_t aBcdef4b = mkldnn_aBcdef4b;
    const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b;
    const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c;
    const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c;
    const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b;
    const format_tag_t aBdc16b = mkldnn_aBdc16b;
    const format_tag_t aBdc4b = mkldnn_aBdc4b;
    const format_tag_t aBdc8b = mkldnn_aBdc8b;
    const format_tag_t aBdec16b = mkldnn_aBdec16b;
    const format_tag_t aBdec4b = mkldnn_aBdec4b;
    const format_tag_t aBdec8b = mkldnn_aBdec8b;
    const format_tag_t aBdefc16b = mkldnn_aBdefc16b;
    const format_tag_t aBdefc4b = mkldnn_aBdefc4b;
    const format_tag_t aBdefc8b = mkldnn_aBdefc8b;
    const format_tag_t Acb16a = mkldnn_Acb16a;
    const format_tag_t Acb4a = mkldnn_Acb4a;
    const format_tag_t Acb8a = mkldnn_Acb8a;
    const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c;
    const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c;
    const format_tag_t Acdb16a = mkldnn_Acdb16a;
    const format_tag_t Acdb4a = mkldnn_Acdb4a;
    const format_tag_t Acdb8a = mkldnn_Acdb8a;
    const format_tag_t Acdeb16a = mkldnn_Acdeb16a;
    const format_tag_t Acdeb4a = mkldnn_Acdeb4a;
    const format_tag_t Acdeb8a = mkldnn_Acdeb8a;
    const format_tag_t BAc16a16b = mkldnn_BAc16a16b;
    const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b;
    const format_tag_t last = mkldnn_format_tag_last;

    // Semantic aliases: data (n/c/d/h/w), weights (o/i/g/...), and RNN
    // (t/n/c, l/d/s/g/...) layouts.
    const format_tag_t x = mkldnn_x;
    const format_tag_t nc = mkldnn_nc;
    const format_tag_t cn = mkldnn_cn;
    const format_tag_t ncw = mkldnn_ncw;
    const format_tag_t nwc = mkldnn_nwc;
    const format_tag_t nchw = mkldnn_nchw;
    const format_tag_t nhwc = mkldnn_nhwc;
    const format_tag_t chwn = mkldnn_chwn;
    const format_tag_t ncdhw = mkldnn_ncdhw;
    const format_tag_t ndhwc = mkldnn_ndhwc;
    const format_tag_t oi = mkldnn_oi;
    const format_tag_t io = mkldnn_io;
    const format_tag_t oiw = mkldnn_oiw;
    const format_tag_t wio = mkldnn_wio;
    const format_tag_t oihw = mkldnn_oihw;
    const format_tag_t hwio = mkldnn_hwio;
    const format_tag_t ihwo = mkldnn_ihwo;
    const format_tag_t iohw = mkldnn_iohw;
    const format_tag_t oidhw = mkldnn_oidhw;
    const format_tag_t dhwio = mkldnn_dhwio;
    const format_tag_t goiw = mkldnn_goiw;
    const format_tag_t goihw = mkldnn_goihw;
    const format_tag_t hwigo = mkldnn_hwigo;
    const format_tag_t giohw = mkldnn_giohw;
    const format_tag_t goidhw = mkldnn_goidhw;
    const format_tag_t tnc = mkldnn_tnc;
    const format_tag_t ntc = mkldnn_ntc;
    const format_tag_t ldsnc = mkldnn_ldsnc;
    const format_tag_t ldigo = mkldnn_ldigo;
    const format_tag_t ldgoi = mkldnn_ldgoi;
    const format_tag_t ldgo = mkldnn_ldgo;
    const format_tag_t nCdhw16c = mkldnn_nCdhw16c;
    const format_tag_t nCdhw4c = mkldnn_nCdhw4c;
    const format_tag_t nCdhw8c = mkldnn_nCdhw8c;
    const format_tag_t nChw16c = mkldnn_nChw16c;
    const format_tag_t nChw4c = mkldnn_nChw4c;
    const format_tag_t nChw8c = mkldnn_nChw8c;
    const format_tag_t nCw16c = mkldnn_nCw16c;
    const format_tag_t nCw4c = mkldnn_nCw4c;
    const format_tag_t nCw8c = mkldnn_nCw8c;
    const format_tag_t IOw16o16i = mkldnn_IOw16o16i;
    const format_tag_t OIw16i16o = mkldnn_OIw16i16o;
    const format_tag_t OIw16o16i = mkldnn_OIw16o16i;
    const format_tag_t Oiw16o = mkldnn_Oiw16o;
    const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i;
    const format_tag_t OIw4i4o = mkldnn_OIw4i4o;
    const format_tag_t Oiw4o = mkldnn_Oiw4o;
    const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i;
    const format_tag_t OIw8i8o = mkldnn_OIw8i8o;
    const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o;
    const format_tag_t OIw8o8i = mkldnn_OIw8o8i;
    const format_tag_t Owi16o = mkldnn_Owi16o;
    const format_tag_t Owi4o = mkldnn_Owi4o;
    const format_tag_t Owi8o = mkldnn_Owi8o;
    const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i;
    const format_tag_t Ohwi16o = mkldnn_Ohwi16o;
    const format_tag_t Ohwi4o = mkldnn_Ohwi4o;
    const format_tag_t Ohwi8o = mkldnn_Ohwi8o;
    const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o;
    const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i;
    const format_tag_t Oihw16o = mkldnn_Oihw16o;
    const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i;
    const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o;
    const format_tag_t Oihw4o = mkldnn_Oihw4o;
    const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i;
    const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o;
    const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o;
    const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i;
    const format_tag_t Odhwi16o = mkldnn_Odhwi16o;
    const format_tag_t Odhwi4o = mkldnn_Odhwi4o;
    const format_tag_t Odhwi8o = mkldnn_Odhwi8o;
    const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o;
    const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i;
    const format_tag_t Oidhw16o = mkldnn_Oidhw16o;
    const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o;
    const format_tag_t Oidhw4o = mkldnn_Oidhw4o;
    const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i;
    const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o;
    const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i;
    const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i;
    const format_tag_t Goiw16g = mkldnn_Goiw16g;
    const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o;
    const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i;
    const format_tag_t gOiw16o = mkldnn_gOiw16o;
    const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i;
    const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o;
    const format_tag_t gOiw4o = mkldnn_gOiw4o;
    const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i;
    const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o;
    const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o;
    const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i;
    const format_tag_t gOwi16o = mkldnn_gOwi16o;
    const format_tag_t gOwi4o = mkldnn_gOwi4o;
    const format_tag_t gOwi8o = mkldnn_gOwi8o;
    const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i;
    const format_tag_t gOhwi16o = mkldnn_gOhwi16o;
    const format_tag_t gOhwi4o = mkldnn_gOhwi4o;
    const format_tag_t gOhwi8o = mkldnn_gOhwi8o;
    const format_tag_t Goihw16g = mkldnn_Goihw16g;
    const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o;
    const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i;
    const format_tag_t gOihw16o = mkldnn_gOihw16o;
    const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i;
    const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i;
    const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o;
    const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i;
    const format_tag_t gOihw4o = mkldnn_gOihw4o;
    const format_tag_t Goihw8g = mkldnn_Goihw8g;
    const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i;
    const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o;
    const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o;
    const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i;
    const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o;
    const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o;
    const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o;
    const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o;
    const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i;
    const format_tag_t gOidhw16o = mkldnn_gOidhw16o;
    const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o;
    const format_tag_t gOidhw4o = mkldnn_gOidhw4o;
    const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i;
    const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o;
    const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i;
}
+
+using memory_extra_flags_t = mkldnn_memory_extra_flags_t;
+namespace memory_extra_flags {
+ const memory_extra_flags_t none = mkldnn_memory_extra_flag_none;
+ const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8;
+ const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust;
+}
+
+using padding_kind_t = mkldnn_padding_kind_t;
+namespace padding_kind {
+ const padding_kind_t padding_zero = mkldnn_padding_zero;
+}
+
+using engine_kind_t = mkldnn_engine_kind_t;
+namespace engine_kind {
+ const engine_kind_t any_engine = mkldnn_any_engine;
+ const engine_kind_t cpu = mkldnn_cpu;
+}
+
+using primitive_kind_t = mkldnn_primitive_kind_t;
+namespace primitive_kind {
+ const primitive_kind_t undefined = mkldnn_undefined_primitive;
+ const primitive_kind_t reorder = mkldnn_reorder;
+ const primitive_kind_t concat = mkldnn_concat;
+ const primitive_kind_t sum = mkldnn_sum;
+ const primitive_kind_t convolution = mkldnn_convolution;
+ const primitive_kind_t deconvolution = mkldnn_deconvolution;
+ const primitive_kind_t shuffle = mkldnn_shuffle;
+ const primitive_kind_t eltwise = mkldnn_eltwise;
+ const primitive_kind_t softmax = mkldnn_softmax;
+ const primitive_kind_t pooling = mkldnn_pooling;
+ const primitive_kind_t lrn = mkldnn_lrn;
+ const primitive_kind_t batch_normalization = mkldnn_batch_normalization;
+ const primitive_kind_t inner_product = mkldnn_inner_product;
+ const primitive_kind_t rnn = mkldnn_rnn;
+}
+
+using query_t = mkldnn_query_t;
+namespace query {
+ const query_t undef = mkldnn_query_undef;
+
+ const query_t engine = mkldnn_query_engine;
+ const query_t primitive_kind = mkldnn_query_primitive_kind;
+
+ const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32;
+ const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32;
+
+ const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64;
+ const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64;
+
+ const query_t scratchpad_engine = mkldnn_query_scratchpad_engine;
+
+ const query_t impl_info_str = mkldnn_query_impl_info_str;
+
+ const query_t some_d = mkldnn_query_some_d;
+ const query_t op_d = mkldnn_query_op_d;
+ const query_t convolution_d = mkldnn_query_convolution_d;
+ const query_t deconvolution_d = mkldnn_query_deconvolution_d;
+ const query_t shuffle_d = mkldnn_query_shuffle_d;
+ const query_t eltwise_d = mkldnn_query_eltwise_d;
+ const query_t softmax_d = mkldnn_query_softmax_d;
+ const query_t pooling_d = mkldnn_query_pooling_d;
+ const query_t lrn_d = mkldnn_query_lrn_d;
+ const query_t batch_normalization_d = mkldnn_query_batch_normalization_d;
+ const query_t inner_product_d = mkldnn_query_inner_product_d;
+ const query_t rnn_d = mkldnn_query_rnn_d;
+
+ const query_t some_md = mkldnn_query_some_md;
+ const query_t src_md = mkldnn_query_src_md;
+ const query_t diff_src_md = mkldnn_query_diff_src_md;
+ const query_t weights_md = mkldnn_query_weights_md;
+ const query_t diff_weights_md = mkldnn_query_diff_weights_md;
+ const query_t dst_md = mkldnn_query_dst_md;
+ const query_t diff_dst_md = mkldnn_query_diff_dst_md;
+
+ const query_t workspace_md = mkldnn_query_workspace_md;
+ const query_t scratchpad_md = mkldnn_query_scratchpad_md;
+}
+
+using blocking_desc_t = mkldnn_blocking_desc_t;
+using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t;
+using wino_desc_t = mkldnn_wino_desc_t;
+using memory_extra_desc_t = mkldnn_memory_extra_desc_t;
+using memory_desc_t = mkldnn_memory_desc_t;
+using convolution_desc_t = mkldnn_convolution_desc_t;
+using deconvolution_desc_t = mkldnn_deconvolution_desc_t;
+using shuffle_desc_t = mkldnn_shuffle_desc_t;
+using pooling_desc_t = mkldnn_pooling_desc_t;
+using eltwise_desc_t = mkldnn_eltwise_desc_t;
+using softmax_desc_t = mkldnn_softmax_desc_t;
+using lrn_desc_t = mkldnn_lrn_desc_t;
+using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t;
+using inner_product_desc_t = mkldnn_inner_product_desc_t;
+
+using rnn_direction_t = mkldnn_rnn_direction_t;
+using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t;
+using rnn_desc_t = mkldnn_rnn_desc_t;
+
+/* C op_desc_t, which eventually are just (void*) */
+using c_op_desc_t = mkldnn_op_desc_t;
+using const_c_op_desc_t = const_mkldnn_op_desc_t;
+
+struct op_desc_t {
+ union {
+ primitive_kind_t kind;
+ convolution_desc_t convolution;
+ deconvolution_desc_t deconvolution;
+ shuffle_desc_t shuffle;
+ pooling_desc_t pooling;
+ eltwise_desc_t eltwise;
+ softmax_desc_t softmax;
+ lrn_desc_t lrn;
+ batch_normalization_desc_t batch_normalization;
+ inner_product_desc_t inner_product;
+ rnn_desc_t rnn;
+ };
+
+ op_desc_t(const primitive_kind_t &_): kind(_) {}
+
+# define DECL_CTOR_AND_CONVERTERS(c_type, name) \
+ op_desc_t(const c_type &_): name(_) {} \
+ static op_desc_t *convert_from_c(c_type *_) \
+ { return reinterpret_cast<op_desc_t*>(_); } \
+ static const op_desc_t *convert_from_c(const c_type *_) \
+ { return reinterpret_cast<const op_desc_t*>(_); }
+
+ DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution);
+ DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle);
+ DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling);
+ DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise);
+ DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax);
+ DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn);
+ DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization);
+ DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product);
+ DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn);
+
+# undef DECL_CTOR_AND_CONVERTERS
+};
+
+using engine_t = mkldnn_engine;
+using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator;
+using primitive_desc_t = mkldnn_primitive_desc;
+using primitive_attr_t = mkldnn_primitive_attr;
+using post_ops_t = mkldnn_post_ops;
+using memory_t = mkldnn_memory;
+using primitive_t = mkldnn_primitive;
+
+using primitive_arg_index_t = int;
+
+using stream_flags_t = mkldnn_stream_flags_t;
+namespace stream_flags {
+ const stream_flags_t default_flags = mkldnn_stream_default_flags;
+}
+using stream_t = mkldnn_stream;
+
+/* forward declaration of the internal primitive_desc types */
+struct batch_normalization_bwd_pd_t;
+struct batch_normalization_fwd_pd_t;
+struct batch_normalization_pd_t;
+struct concat_pd_t;
+struct convolution_bwd_data_pd_t;
+struct convolution_bwd_weights_pd_t;
+struct convolution_fwd_pd_t;
+struct convolution_pd_t;
+struct deconvolution_bwd_data_pd_t;
+struct deconvolution_bwd_weights_pd_t;
+struct deconvolution_fwd_pd_t;
+struct deconvolution_pd_t;
+struct eltwise_bwd_pd_t;
+struct eltwise_fwd_pd_t;
+struct eltwise_pd_t;
+struct inner_product_bwd_data_pd_t;
+struct inner_product_bwd_weights_pd_t;
+struct inner_product_fwd_pd_t;
+struct inner_product_pd_t;
+struct lrn_bwd_pd_t;
+struct lrn_fwd_pd_t;
+struct lrn_pd_t;
+struct pooling_bwd_pd_t;
+struct pooling_fwd_pd_t;
+struct pooling_pd_t;
+struct reorder_pd_t;
+struct rnn_bwd_pd_t;
+struct rnn_fwd_pd_t;
+struct rnn_pd_t;
+struct shuffle_pd_t;
+struct softmax_bwd_pd_t;
+struct softmax_fwd_pd_t;
+struct softmax_pd_t;
+struct sum_pd_t;
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp
new file mode 100644
index 0000000000..ed4c35c6e9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp
@@ -0,0 +1,86 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "concat_pd.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
+ const memory_desc_t *dst_md, int n, int concat_dim,
+ const memory_desc_t *src_mds,
+ const primitive_attr_t *attr,
+ engine_t *engine) {
+ bool args_ok = !any_null(concat_pd, src_mds) && n > 0;
+ if (!args_ok) return invalid_arguments;
+
+ const primitive_attr_t dummy_attr;
+ if (attr == NULL)
+ attr = &dummy_attr;
+
+ const int ndims = src_mds[0].ndims;
+ const dims_t &dims = src_mds[0].dims;
+ const data_type_t dt = src_mds[0].data_type;
+
+ int concat_dim_sz = dims[concat_dim];
+ for (int i = 1; i < n; ++i) {
+ if (src_mds[i].ndims != ndims) return invalid_arguments;
+ for (int d = 0; d < ndims; ++d) {
+ if (d == concat_dim) continue;
+ if (src_mds[i].dims[d] != dims[d])
+ return invalid_arguments;
+ }
+ if (src_mds[i].data_type != dt) return invalid_arguments;
+ concat_dim_sz += src_mds[i].dims[concat_dim];
+ }
+
+ memory_desc_t dummy_dst_md;
+ if (dst_md) {
+ if (dst_md->ndims != ndims) return invalid_arguments;
+ for (int d = 0; d < ndims; ++d) {
+ if (dst_md->dims[d] !=
+ (d == concat_dim ? concat_dim_sz : dims[d]))
+ return invalid_arguments;
+ }
+ } else {
+ dummy_dst_md = src_mds[0];
+ dummy_dst_md.dims[concat_dim] = concat_dim_sz;
+ dummy_dst_md.format_kind = format_kind::any;
+ dst_md = &dummy_dst_md;
+ }
+
+ auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
+
+ for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
+ if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds)
+ == success) {
+ (*c_pd)->init_info();
+ (*c_pd)->init_scratchpad_md();
+ return success;
+ }
+ }
+ return unimplemented;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
new file mode 100644
index 0000000000..29311927e2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
@@ -0,0 +1,211 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CONCAT_PD_HPP
+#define CONCAT_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct concat_pd_t: public primitive_desc_t {
+ concat_pd_t(engine_t *engine, const primitive_attr_t *attr,
+ const memory_desc_t *dst_md, int n, int concat_dim,
+ const memory_desc_t *src_mds)
+ : primitive_desc_t(engine, attr, primitive_kind::concat)
+ , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md)
+ {
+ src_mds_.reserve(n_);
+ for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
+ }
+
+ concat_pd_t(const concat_pd_t &rhs) = default;
+
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg >= MKLDNN_ARG_MULTIPLE_SRC
+ && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index < n_inputs() ? &src_mds_[index] : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return n_; }
+ virtual int n_outputs() const override { return 1; }
+
+ int concat_dim() const { return concat_dim_; }
+
+ const memory_desc_t *src_image_md(int index = 0) const
+ { return index < n_inputs() ? &src_image_mds_[index] : nullptr; }
+
+protected:
+ int n_, concat_dim_;
+ memory_desc_t dst_md_;
+ nstl::vector<memory_desc_t> src_mds_;
+
+ /* contains images of srcs in the dst memory (if possible)
+ * Lives here to simplify some implementations. An implementation might
+ * use this auxiliary array iff init() returned success */
+ nstl::vector<memory_desc_t> src_image_mds_;
+
+protected:
+ /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */
+ status_t init() {
+ bool ok = true
+ && set_default_params() == status::success
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ for (int i = 0; i < n_; ++i) {
+ const memory_desc_wrapper i_d(&src_mds_[i]);
+ if (!i_d.is_blocking_desc() || i_d.is_additional_buffer())
+ return status::unimplemented;
+ }
+
+ const int ndims = dst_md_.ndims;
+ int current_concat_dim_offset = 0;
+ for (int i = 0; i < n_; ++i) {
+ const int dim = src_mds_[i].dims[concat_dim_];
+ dims_t dims, offsets = {};
+ utils::array_copy(dims, dst_md_.dims, ndims);
+ dims[concat_dim_] = dim;
+ offsets[concat_dim_] = current_concat_dim_offset;
+
+ memory_desc_t src_img_d;
+ status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
+ &dst_md_, dims, offsets);
+ if (status != status::success) return status;
+ src_image_mds_.push_back(src_img_d);
+ current_concat_dim_offset += dim;
+ }
+
+ return status::success;
+ }
+
+ status_t set_default_params() {
+ if (dst_md_.format_kind != format_kind::any)
+ return status::success;
+
+ const int ndims = dst_md_.ndims;
+
+ /* The stupidest ever heuristics (but not the same as we had before):
+ * - Pick the first non-plain format;
+ * - If all formats are plain or it is not possible to create a
+ * blocked format for the output, pick the format of the plain input
+ * - If this fails as well, use plain layout (abcd...)
+ */
+ status_t status = status::unimplemented;
+ for (int i = 0; i < n_; ++i) {
+ const memory_desc_wrapper src_d(src_mds_[i]);
+ if (src_d.is_blocking_desc() && !src_d.is_plain()) {
+ status = memory_desc_init_by_blocking_desc(dst_md_,
+ src_d.blocking_desc());
+ if (status == status::success) break;
+ }
+ }
+
+ if (status == status::success) {
+ /* check if we can create a sub-memory for the dst */
+ bool desired_format_ok = true;
+ int current_concat_dim_offset = 0;
+ for (int i = 0; i < n_; ++i) {
+ const int dim = src_mds_[i].dims[concat_dim_];
+ dims_t dims, offsets = {};
+ utils::array_copy(dims, dst_md_.dims, ndims);
+ dims[concat_dim_] = dim;
+ offsets[concat_dim_] = current_concat_dim_offset;
+
+ memory_desc_t src_img_d;
+ status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
+ &dst_md_, dims, offsets);
+ if (status != status::success) {
+ desired_format_ok = false;
+ break;
+ }
+ current_concat_dim_offset += dim;
+ }
+
+ if (!desired_format_ok)
+ status = status::unimplemented;
+ }
+
+ /* if no success so far, try using the format of the first plain input */
+ if (status != status::success) {
+ for (int i = 0; i < n_; ++i) {
+ const memory_desc_wrapper src_d(src_mds_[i]);
+ if (src_d.is_blocking_desc() && src_d.is_plain()) {
+ status = memory_desc_init_by_blocking_desc(dst_md_,
+ memory_desc_wrapper(src_mds_[0]).blocking_desc());
+ if (status == status::success) return status;
+ }
+ }
+ }
+
+ /* the last line of defense: use plain abcd... format */
+ if (status != status::success)
+ status = memory_desc_init_by_strides(dst_md_, nullptr);
+
+ return status;
+ }
+};
+
+#define DECLARE_CONCAT_PD_t(impl_name, ...) \
+ static status_t create(concat_pd_t **concat_pd, \
+ engine_t *engine, const primitive_attr_t *attr, \
+ const memory_desc_t *dst_md, int n, int concat_dim, \
+ const memory_desc_t *src_mds) { \
+ using namespace status; \
+ auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \
+ if (_pd == nullptr) return out_of_memory; \
+ if (_pd->init() != success) { delete _pd; return unimplemented; } \
+ return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \
+ } \
+ virtual status_t create_primitive(primitive_t **p) const override { \
+ double ms = get_msec(); \
+ auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
+ ms = get_msec() - ms; \
+ if (mkldnn_verbose()->level >= 2) { \
+ printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
+ fflush(0); \
+ } \
+ return ret; \
+ } \
+ virtual pd_t *clone() const override { return new pd_t(*this); } \
+ virtual const char *name() const override { return impl_name; } \
+
+#define DECLARE_CONCAT_PD_T(impl_name, ...) \
+ DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__)
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
new file mode 100644
index 0000000000..0c5c02bcd1
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
@@ -0,0 +1,200 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace mkldnn {
+namespace impl {
+status_t conv_desc_init(convolution_desc_t *conv_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t dilates,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ bool args_ok = true
+ && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
+ padding_l)
+ && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd)
+ && one_of(padding_kind, padding_kind::padding_zero);
+ if (!args_ok) return invalid_arguments;
+
+ if (padding_r == nullptr) padding_r = padding_l;
+
+ auto cd = convolution_desc_t();
+ cd.primitive_kind = primitive_kind::convolution;
+ cd.prop_kind = prop_kind;
+ cd.alg_kind = alg_kind;
+
+ cd.diff_src_desc = cd.src_desc = zero_md();
+ cd.diff_dst_desc = cd.dst_desc = zero_md();
+ cd.diff_weights_desc = cd.weights_desc = zero_md();
+ cd.diff_bias_desc = cd.bias_desc = zero_md();
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+ const bool with_bias =
+ bias_desc && bias_desc->format_kind != format_kind::undef;
+ const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
+
+ (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc;
+ (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc;
+ (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) =
+ *weights_desc;
+ if (with_bias)
+ (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) =
+ *bias_desc;
+
+ int sp_dims = src_desc->ndims - 2;
+ utils::array_copy(cd.strides, strides, sp_dims);
+ utils::array_copy(cd.padding[0], padding_l, sp_dims);
+ utils::array_copy(cd.padding[1], padding_r, sp_dims);
+ if (dilates)
+ utils::array_copy(cd.dilates, dilates, sp_dims);
+ else
+ utils::array_set(cd.dilates, 0, sp_dims);
+
+ cd.padding_kind = padding_kind;
+ cd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+ weights_desc->data_type, dst_desc->data_type, prop_kind);
+
+ const int g = with_groups ? weights_desc->dims[0] : 1;
+ const int bias_dim = prop_kind == backward_data
+ ? src_desc->dims[1]
+ : dst_desc->dims[1];
+
+ bool consistency = true
+ && memory_desc_wrapper(weights_desc).nelems()
+ && src_desc->ndims == dst_desc->ndims
+ && utils::one_of(src_desc->ndims, 3, 4, 5)
+ && utils::one_of(weights_desc->ndims, src_desc->ndims,
+ src_desc->ndims + 1)
+ && (with_bias ? bias_desc->ndims == 1 : true)
+ && (with_bias ? bias_desc->dims[0] == bias_dim : true)
+ && src_desc->dims[0] == dst_desc->dims[0]
+ && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
+ && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
+ for (int i = 2; i < src_desc->ndims; ++i)
+ {
+ int src = src_desc->dims[i];
+ int ker = weights_desc->dims[with_groups + i];
+ int dil = cd.dilates[i - 2];
+ int pad_l = padding_l[i - 2];
+ int pad_r = padding_r[i - 2];
+ int str = strides[i - 2];
+ int dst = dst_desc->dims[i];
+ int ker_range = 1 + (ker - 1) * (dil + 1);
+
+ if (str < 1) return invalid_arguments;
+ consistency = consistency
+ && dil >= 0
+ && pad_l >= 0
+ && pad_r + str > 0
+ && (src - ker_range + pad_l + pad_r) / str + 1 == dst;
+ }
+ if (!consistency) return invalid_arguments;
+
+ *conv_desc = cd;
+ return success;
+}
+}
+}
+
+status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
+ weights_desc, bias_desc, dst_desc, strides, nullptr,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_forward_desc_init(
+ convolution_desc_t *conv_desc, prop_kind_t prop_kind,
+ alg_kind_t alg_kind, const memory_desc_t *src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
+ weights_desc, bias_desc, dst_desc, strides, dilates,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_convolution_backward_data_desc_init(
+ convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
+ weights_desc, nullptr, diff_dst_desc, strides, nullptr,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_backward_data_desc_init(
+ convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
+ weights_desc, nullptr, diff_dst_desc, strides, dilates,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_convolution_backward_weights_desc_init(
+ convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
+ diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
+ nullptr, padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_backward_weights_desc_init(
+ convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
+ diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
+ dilates, padding_l, padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp
new file mode 100644
index 0000000000..9604e0acf5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp
@@ -0,0 +1,56 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "utils.hpp"
+
+#include "convolution_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+using namespace prop_kind;
+
+memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) {
+ return desc->prop_kind == backward_data
+ ? &desc->diff_src_desc : &desc->src_desc;
+}
+
+memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) {
+ return desc->prop_kind == backward_weights
+ ? &desc->diff_weights_desc : &desc->weights_desc;
+}
+
+memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) {
+ return desc->prop_kind == backward_weights
+ ? &desc->diff_bias_desc : &desc->bias_desc;
+}
+
+memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) {
+ return utils::one_of(desc->prop_kind, forward_inference, forward_training)
+ ? &desc->dst_desc : &desc->diff_dst_desc;
+}
+
+const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); }
+
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp
new file mode 100644
index 0000000000..b10c36db49
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp
@@ -0,0 +1,348 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CONVOLUTION_PD_HPP
+#define CONVOLUTION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+// Shared initializer for all convolution descriptor variants
+// (defined in convolution.cpp).
+status_t conv_desc_init(convolution_desc_t *conv_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t dilates,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind);
+
+// Accessors returning the src/wei/bia/dst-role descriptors independently of
+// the propagation kind (defined in convolution_pd.cpp).
+memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc);
+
+struct convolution_fwd_pd_t;
+
+// Base primitive descriptor shared by all convolution propagation kinds.
+// Wraps a convolution_desc_t and exposes prop-kind-invariant geometry
+// accessors (spatial sizes, kernel, strides, dilations, padding) so that
+// implementations need not care whether they run forward, backward-data,
+// or backward-weights.
+struct convolution_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::convolution;
+
+ convolution_pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ {}
+
+ const convolution_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ // Handles the convolution-specific query (the operation descriptor);
+ // everything else is delegated to the base class.
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case pkind_traits<base_pkind>::query_d:
+ *(const convolution_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common conv aux functions */
+
+ // Minibatch size (dim 0 of the source-role descriptor).
+ dim_t MB() const { return _src_md()->dims[0]; }
+
+ // Channel counts and group count (1 when weights carry no group dim).
+ dim_t IC() const { return _src_md()->dims[1]; }
+ dim_t OC() const { return _dst_md()->dims[1]; }
+ dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; }
+
+ // Input spatial sizes; dimensions absent for lower-rank problems
+ // default to 1.
+ dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; }
+ dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; }
+ dim_t IW() const { return _src_md()->dims[ndims() - 1]; }
+
+ // Output spatial sizes.
+ dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; }
+ dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; }
+ dim_t OW() const { return _dst_md()->dims[ndims() - 1]; }
+
+ // Kernel spatial sizes; with_groups() shifts the index past the
+ // leading group dimension of the weights.
+ dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; }
+ dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; }
+ dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; }
+
+ // Strides per spatial dimension (1 for dimensions that do not exist).
+ dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+ dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+ dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+ // Dilations per spatial dimension.
+ // NOTE(review): KDD falls back to 0 but KDH falls back to 1 — the
+ // asymmetry looks unintentional, though harmless because the height
+ // kernel size is 1 whenever the fallback is taken; confirm against
+ // upstream mkl-dnn before changing.
+ dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
+ dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
+ dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
+
+ // Padding: padding[0] is the beginning (front/top/left), padding[1]
+ // the end (back/bottom/right) of each spatial dimension.
+ dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+ dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+ dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+ dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+ dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+ dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+ // Tensor rank of the source-role descriptor (2 + spatial dims).
+ int ndims() const { return _src_md()->ndims; }
+
+ // Bias presence is encoded as a non-zeroed bias memory descriptor.
+ bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); }
+ bool with_groups() const { return _wei_md()->ndims == ndims() + 1; }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+ // True when src or dst has a zero-sized dimension (empty problem).
+ bool has_zero_dim_memory() const {
+ const auto s_d = memory_desc_wrapper(*_src_md());
+ const auto d_d = memory_desc_wrapper(*_dst_md());
+ return s_d.has_zero_dim() || d_d.has_zero_dim();
+ }
+
+protected:
+ convolution_desc_t desc_;
+ const convolution_fwd_pd_t *hint_fwd_pd_;
+
+ // For each descriptor still tagged format_kind::any, initialize it
+ // with the supplied default tag (bias always gets the 1D tag 'x');
+ // returns false if any initialization fails.
+ bool set_default_formats_common_template(
+ memory_desc_t &src_md, format_tag_t src_tag,
+ memory_desc_t &wei_md, format_tag_t wei_tag,
+ memory_desc_t &dst_md, format_tag_t dst_tag,
+ memory_desc_t &bia_md) {
+ using namespace format_tag;
+
+# define IS_OK(f) \
+ do { if ((f) != status::success) return false; } while(0)
+ if (src_md.format_kind == format_kind::any
+ && !utils::one_of(src_tag, any, undef))
+ IS_OK(memory_desc_init_by_tag(src_md, src_tag));
+ if (dst_md.format_kind == format_kind::any
+ && !utils::one_of(dst_tag, any, undef))
+ IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
+ if (wei_md.format_kind == format_kind::any
+ && !utils::one_of(wei_tag, any, undef))
+ IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
+ if (with_bias() && bia_md.format_kind == format_kind::any)
+ IS_OK(memory_desc_init_by_tag(bia_md, x));
+# undef IS_OK
+
+ return true;
+ }
+
+ // Resolve convolution_auto to the implementation's concrete algorithm;
+ // returns whether the (possibly updated) kind matches the request.
+ bool set_default_alg_kind(alg_kind_t alg_kind) {
+ assert(utils::one_of(alg_kind, alg_kind::convolution_direct,
+ alg_kind::convolution_winograd));
+ if (desc_.alg_kind == alg_kind::convolution_auto)
+ desc_.alg_kind = alg_kind;
+ return desc_.alg_kind == alg_kind;
+ }
+
+ // Check tensor data types against expectations; data_type::undef acts
+ // as a wildcard, and bias is only checked when present.
+ bool expect_data_types(data_type_t src_dt, data_type_t wei_dt,
+ data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const {
+ bool ok = true
+ && (src_dt == data_type::undef || _src_md()->data_type == src_dt)
+ && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt)
+ && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt)
+ && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt);
+ if (with_bias() && bia_dt != data_type::undef)
+ ok = ok && _bia_md()->data_type == bia_dt;
+ return ok;
+ }
+
+private:
+ // Prop-kind-invariant views into desc_ (see conv_prop_invariant_*_d).
+ const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); }
+ const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); }
+ const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); }
+ const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); }
+};
+
+// Forward convolution pd: caches the forward memory descriptors and maps
+// runtime arguments (src/weights/bias are inputs, dst is the output).
+struct convolution_fwd_pd_t: public convolution_pd_t {
+ typedef convolution_fwd_pd_t base_class;
+ typedef convolution_fwd_pd_t hint_class;
+
+ convolution_fwd_pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , weights_md_(desc_.weights_desc)
+ , bias_md_(desc_.bias_desc)
+ , dst_md_(desc_.dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_BIAS && with_bias())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+ // index 0: weights; index 1: bias (only when present).
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_md_;
+ if (index == 1 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2 + with_bias(); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t dst_md_;
+
+ // Convenience wrapper binding this pd's cached descriptors to the
+ // generic default-format initializer in the base class.
+ bool set_default_formats_common(format_tag_t src_tag,
+ format_tag_t wei_tag, format_tag_t dst_tag) {
+ return set_default_formats_common_template(src_md_, src_tag,
+ weights_md_, wei_tag, dst_md_, dst_tag, bias_md_);
+ }
+};
+
+// Backward-data convolution pd: weights and diff_dst are inputs,
+// diff_src is the output.
+struct convolution_bwd_data_pd_t: public convolution_pd_t {
+ typedef convolution_bwd_data_pd_t base_class;
+ typedef convolution_fwd_pd_t hint_class;
+
+ convolution_bwd_data_pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_md_(desc_.diff_src_desc)
+ , weights_md_(desc_.weights_desc)
+ , bias_md_(desc_.bias_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_md_;
+ if (index == 1 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2 + with_bias(); }
+ virtual int n_outputs() const override { return 1; }
+
+ // Implementations that can fuse the bias into backward-data override
+ // this to return true; default is no bias support.
+ virtual bool support_bias() const { return false; }
+
+protected:
+ memory_desc_t diff_src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t diff_dst_md_;
+
+ bool set_default_formats_common(format_tag_t diff_src_tag,
+ format_tag_t wei_tag, format_tag_t diff_dst_tag) {
+ return set_default_formats_common_template(diff_src_md_, diff_src_tag,
+ weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
+ }
+};
+
+// Backward-weights convolution pd: src and diff_dst are inputs,
+// diff_weights (and diff_bias, when present) are outputs.
+struct convolution_bwd_weights_pd_t: public convolution_pd_t {
+ typedef convolution_bwd_weights_pd_t base_class;
+ typedef convolution_fwd_pd_t hint_class;
+
+ convolution_bwd_weights_pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , diff_weights_md_(desc_.diff_weights_desc)
+ , diff_bias_md_(desc_.diff_bias_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ // index 0: weight gradients; index 1: bias gradient (when present).
+ virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+ if (index == 0) return &diff_weights_md_;
+ if (index == 1 && with_bias()) return &diff_bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t diff_weights_md_;
+ memory_desc_t diff_bias_md_;
+ memory_desc_t diff_dst_md_;
+
+ bool set_default_formats_common(format_tag_t src_tag,
+ format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) {
+ return set_default_formats_common_template(src_md_, src_tag,
+ diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag,
+ diff_bias_md_);
+ }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
new file mode 100644
index 0000000000..98063c1c37
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
@@ -0,0 +1,188 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+// Common initializer behind every mkldnn_*deconvolution*_desc_init entry
+// point. prop_kind selects which plain vs diff descriptors are filled;
+// dilates == nullptr means a non-dilated deconvolution (dilations set to
+// 0) and padding_r == nullptr means symmetric padding (padding_r takes
+// padding_l). Returns invalid_arguments on any inconsistency.
+status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t dilates, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ bool args_ok = true
+ && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides,
+ padding_l)
+ && one_of(alg_kind, deconvolution_direct, deconvolution_winograd)
+ && one_of(padding_kind, padding_kind::padding_zero);
+ if (!args_ok)
+ return invalid_arguments;
+
+ if (padding_r == nullptr)
+ padding_r = padding_l;
+
+ auto dd = deconvolution_desc_t();
+ dd.primitive_kind = primitive_kind::deconvolution;
+ dd.prop_kind = prop_kind;
+ dd.alg_kind = alg_kind;
+
+ // Zero every descriptor first; only the slots relevant to prop_kind
+ // are overwritten below, so with_bias()/is_zero() checks stay valid.
+ dd.diff_src_desc = dd.src_desc = zero_md();
+ dd.diff_dst_desc = dd.dst_desc = zero_md();
+ dd.diff_weights_desc = dd.weights_desc = zero_md();
+ dd.diff_bias_desc = dd.bias_desc = zero_md();
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+ const bool with_bias
+ = bias_desc && bias_desc->format_kind != format_kind::undef;
+ const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
+
+ (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc;
+ (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc;
+ (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc)
+ = *weights_desc;
+ if (with_bias)
+ (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc)
+ = *bias_desc;
+
+ // Copy the per-spatial-dimension parameters (ndims - 2 of them).
+ int sp_dims = src_desc->ndims - 2;
+ utils::array_copy(dd.strides, strides, sp_dims);
+ utils::array_copy(dd.padding[0], padding_l, sp_dims);
+ utils::array_copy(dd.padding[1], padding_r, sp_dims);
+ if (dilates)
+ utils::array_copy(dd.dilates, dilates, sp_dims);
+ else
+ utils::array_set(dd.dilates, 0, sp_dims);
+
+ dd.padding_kind = padding_kind;
+ dd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+ weights_desc->data_type, dst_desc->data_type, prop_kind);
+
+ // Shape consistency: ranks, channel/group relationships, and the
+ // spatial formula with src and dst swapped relative to convolution
+ // (deconvolution is the transpose of convolution).
+ const int g = with_groups ? weights_desc->dims[0] : 1;
+ bool consistency = true
+ && src_desc->ndims == dst_desc->ndims
+ && utils::one_of(src_desc->ndims, 3, 4, 5)
+ && utils::one_of(weights_desc->ndims, src_desc->ndims,
+ src_desc->ndims + 1)
+ && (with_bias ? bias_desc->ndims == 1 : true)
+ && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
+ && src_desc->dims[0] == dst_desc->dims[0]
+ && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
+ && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
+ for (int i = 2; i < src_desc->ndims; ++i) {
+ // NOTE(review): dims are narrowed to int here while descriptors
+ // store them as dim_t elsewhere — confirm overflow is not a
+ // concern for supported problem sizes.
+ int src = src_desc->dims[i];
+ int ker = weights_desc->dims[with_groups + i];
+ int dil = dd.dilates[i - 2];
+ int pad = padding_l[i - 2] + padding_r[i - 2];
+ int str = strides[i - 2];
+ int dst = dst_desc->dims[i];
+ // Effective kernel extent once dilation is applied.
+ int ker_range = 1 + (ker - 1) * (dil + 1);
+
+ consistency
+ = consistency && (dst - ker_range + pad) / str + 1 == src;
+ }
+ if (!consistency)
+ return invalid_arguments;
+
+ *deconv_desc = dd;
+ return success;
+}
+}
+
+// C API: initialize a forward (non-dilated) deconvolution descriptor;
+// only the two forward prop kinds are accepted.
+status_t mkldnn_deconvolution_forward_desc_init(
+ deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
+ alg_kind_t alg_kind, const memory_desc_t *src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_desc, const dims_t strides,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
+ weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l,
+ padding_r, padding_kind);
+}
+
+// C API: forward deconvolution with explicit dilations.
+status_t mkldnn_dilated_deconvolution_forward_desc_init(
+ deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
+ alg_kind_t alg_kind, const memory_desc_t *src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
+ weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
+ padding_r, padding_kind);
+}
+
+// C API: backward-data deconvolution (non-dilated); bias is not involved.
+status_t mkldnn_deconvolution_backward_data_desc_init(
+ deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
+ weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l,
+ padding_r, padding_kind);
+}
+
+// C API: backward-data deconvolution with explicit dilations.
+status_t mkldnn_dilated_deconvolution_backward_data_desc_init(
+ deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
+ weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l,
+ padding_r, padding_kind);
+}
+
+// C API: backward-weights deconvolution (non-dilated), with optional
+// bias gradient.
+status_t mkldnn_deconvolution_backward_weights_desc_init(
+ deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
+ const dims_t strides, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
+ diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr,
+ padding_l, padding_r, padding_kind);
+}
+
+// C API: backward-weights deconvolution with explicit dilations.
+status_t mkldnn_dilated_deconvolution_backward_weights_desc_init(
+ deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
+ const dims_t strides, const dims_t dilates, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
+ diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates,
+ padding_l, padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp
new file mode 100644
index 0000000000..539e44bd9b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp
@@ -0,0 +1,293 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef DECONVOLUTION_PD_HPP
+#define DECONVOLUTION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "convolution_pd.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct deconvolution_fwd_pd_t;
+
+// Base primitive descriptor for deconvolution. Mirrors convolution_pd_t;
+// because deconvolution_desc_t has the same layout as convolution_desc_t,
+// the conv_prop_invariant_* helpers are reused directly on desc_.
+struct deconvolution_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::deconvolution;
+
+ deconvolution_pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ {}
+
+ const deconvolution_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ // Handles the deconvolution-specific query; everything else is
+ // delegated to the base class.
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case pkind_traits<base_pkind>::query_d:
+ *(const deconvolution_desc_t **)result = desc();
+ break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
+
+ // Minibatch size.
+ dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; }
+
+ // Channel counts and group count (1 when weights carry no group dim).
+ dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; }
+ dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; }
+ dim_t G() const
+ { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; }
+
+ // Input spatial sizes; absent dimensions default to 1.
+ dim_t ID() const {
+ return ndims() >= 5
+ ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t IH() const {
+ return ndims() >= 4
+ ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t IW() const {
+ return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1];
+ }
+
+ // Output spatial sizes.
+ dim_t OD() const {
+ return ndims() >= 5
+ ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t OH() const {
+ return ndims() >= 4
+ ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t OW() const {
+ return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1];
+ }
+
+ // Kernel spatial sizes; with_groups() shifts past the group dimension.
+ dim_t KD() const {
+ const int w_ndims = ndims() + with_groups();
+ return ndims() >= 5
+ ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1;
+ }
+ dim_t KH() const {
+ const int w_ndims = ndims() + with_groups();
+ return ndims() >= 4
+ ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1;
+ }
+ dim_t KW() const {
+ const int w_ndims = ndims() + with_groups();
+ return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1];
+ }
+
+ // Strides per spatial dimension (1 for absent dims).
+ dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+ dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+ dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+ // Dilations per spatial dimension.
+ // NOTE(review): same fallback asymmetry as convolution_pd_t (KDD -> 0,
+ // KDH -> 1); harmless in practice but worth confirming upstream.
+ dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
+ dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
+ dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
+
+ // Padding: padding[0] = beginning, padding[1] = end of each dimension.
+ dim_t padFront() const
+ { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+ dim_t padBack() const
+ { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+ dim_t padT() const
+ { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+ dim_t padB() const
+ { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+ dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+ dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+ // Bias presence is encoded as a non-zeroed bias memory descriptor.
+ bool with_bias() const {
+ return
+ !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero();
+ }
+
+ bool with_groups() const
+ { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; }
+
+ int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+ // True when src or dst has a zero-sized dimension (empty problem).
+ bool has_zero_dim_memory() const {
+ const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_));
+ const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_));
+ return s_d.has_zero_dim() || d_d.has_zero_dim();
+ }
+
+protected:
+ deconvolution_desc_t desc_;
+ const deconvolution_fwd_pd_t *hint_fwd_pd_;
+};
+
+// Forward deconvolution pd: src/weights/bias are inputs, dst is output.
+struct deconvolution_fwd_pd_t: public deconvolution_pd_t {
+ typedef deconvolution_fwd_pd_t base_class;
+ typedef deconvolution_fwd_pd_t hint_class;
+
+ deconvolution_fwd_pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , weights_md_(desc_.weights_desc)
+ , bias_md_(desc_.bias_desc)
+ , dst_md_(desc_.dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_BIAS && with_bias())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+ // index 0: weights; index 1: bias (only when present).
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_md_;
+ if (index == 1 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2 + with_bias(); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t dst_md_;
+};
+
+// Backward-data deconvolution pd: weights and diff_dst are inputs,
+// diff_src is the output. Unlike convolution_bwd_data_pd_t, no bias
+// descriptor is cached here.
+struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t {
+ typedef deconvolution_bwd_data_pd_t base_class;
+ typedef deconvolution_fwd_pd_t hint_class;
+
+ deconvolution_bwd_data_pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_md_(desc_.diff_src_desc)
+ , weights_md_(desc_.weights_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override
+ { return index == 0 ? &weights_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t diff_dst_md_;
+};
+
+// Backward-weights deconvolution pd: src and diff_dst are inputs,
+// diff_weights (and diff_bias, when present) are outputs.
+struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t {
+ typedef deconvolution_bwd_weights_pd_t base_class;
+ typedef deconvolution_fwd_pd_t hint_class;
+
+ deconvolution_bwd_weights_pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , diff_weights_md_(desc_.diff_weights_desc)
+ , diff_bias_md_(desc_.diff_bias_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ // index 0: weight gradients; index 1: bias gradient (when present).
+ virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+ if (index == 0) return &diff_weights_md_;
+ if (index == 1 && with_bias()) return &diff_bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t diff_weights_md_;
+ memory_desc_t diff_bias_md_;
+ memory_desc_t diff_dst_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
new file mode 100644
index 0000000000..f1708fca52
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
@@ -0,0 +1,84 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+/* Validates arguments and fills in an eltwise operation descriptor.
+ * Handles forward (training/inference) and backward-by-data propagation;
+ * diff_data_desc is required only for backward_data. Returns
+ * invalid_arguments on any inconsistency, success otherwise. */
+status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
+ alg_kind_t alg_kind, const memory_desc_t *data_desc,
+ const memory_desc_t *diff_data_desc, float alpha, float beta) {
+ bool args_ok = true
+ && !any_null(eltwise_desc, data_desc)
+ && one_of(prop_kind, forward_training, forward_inference,
+ backward_data)
+ && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)
+ && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
+ if (!args_ok) return invalid_arguments;
+
+ auto ed = eltwise_desc_t();
+ ed.primitive_kind = primitive_kind::eltwise;
+ ed.prop_kind = prop_kind;
+ ed.alg_kind = alg_kind;
+
+ ed.data_desc = *data_desc;
+ /* the diff descriptor is only meaningful for backward_data; zero it
+ * otherwise so forward descriptors compare/hash consistently */
+ ed.diff_data_desc =
+ (ed.prop_kind == backward_data) ? *diff_data_desc : zero_md();
+
+ /* alpha/beta are the algorithm-specific parameters (their meaning
+ * depends on alg_kind) */
+ ed.alpha = alpha;
+ ed.beta = beta;
+
+ /* for backward_data the diff and data descriptors must agree in dims */
+ bool consistency = true
+ && IMPLICATION(ed.prop_kind == backward_data,
+ array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims,
+ ed.diff_data_desc.ndims));
+ if (!consistency) return invalid_arguments;
+
+ *eltwise_desc = ed;
+ return success;
+}
+}
+
+/* Public C API: initializes a forward eltwise descriptor. Rejects any
+ * non-forward prop_kind before delegating to the common initializer. */
+status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *data_desc, float alpha, float beta) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc,
+ nullptr, alpha, beta);
+}
+
+/* Public C API: initializes a backward-by-data eltwise descriptor;
+ * prop_kind is fixed to backward_data by construction. */
+status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc,
+ alg_kind_t alg_kind, const memory_desc_t *diff_data_desc,
+ const memory_desc_t *data_desc, float alpha, float beta) {
+ return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc,
+ diff_data_desc, alpha, beta);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
new file mode 100644
index 0000000000..9fd260fcee
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
@@ -0,0 +1,161 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef ELTWISE_PD_HPP
+#define ELTWISE_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct eltwise_fwd_pd_t;
+
+/** Common base primitive descriptor for eltwise (forward and backward).
+ * Stores a copy of the operation descriptor and the data memory
+ * descriptor, and provides shape helpers shared by both directions. */
+struct eltwise_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::eltwise;
+
+ eltwise_pd_t(mkldnn::impl::engine_t *engine,
+ const eltwise_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const eltwise_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , data_md_(desc_.data_desc)
+ {}
+
+ const eltwise_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ /** answers query::eltwise_d; everything else goes to the base class */
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::eltwise_d:
+ *(const eltwise_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common eltwise aux functions */
+
+ /* D/H/W are read from the trailing dims; missing spatial dims are 1 */
+ dim_t MB() const { return data_desc().dims[0]; }
+ dim_t C() const { return data_desc().dims[1]; }
+ dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+ dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+ dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+ int ndims() const { return data_desc().ndims; }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+ /** true when any dim of the data tensor is zero (nothing to compute) */
+ bool has_zero_dim_memory() const
+ { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
+
+protected:
+ eltwise_desc_t desc_;
+ const eltwise_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t data_md_;
+
+private:
+ const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+/** Primitive descriptor base for forward eltwise. src and dst share the
+ * same memory descriptor (data_md_), i.e. eltwise is shape-preserving. */
+struct eltwise_fwd_pd_t: public eltwise_pd_t {
+ typedef eltwise_fwd_pd_t base_class;
+ typedef eltwise_fwd_pd_t hint_class;
+
+ eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine,
+ const eltwise_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const eltwise_fwd_pd_t *hint_fwd_pd)
+ : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ /* src and dst deliberately return the same descriptor */
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override { return 1; }
+
+ /** whether f(0) == 0 for the chosen algorithm (enables zero-skipping) */
+ bool is_zero_preserved() const
+ { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); }
+};
+
+/** Primitive descriptor base for backward-by-data eltwise.
+ * diff_src and diff_dst share one descriptor (diff_data_md_). */
+struct eltwise_bwd_pd_t: public eltwise_pd_t {
+ typedef eltwise_bwd_pd_t base_class;
+ typedef eltwise_fwd_pd_t hint_class;
+
+ eltwise_bwd_pd_t(engine_t *engine,
+ const eltwise_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const eltwise_fwd_pd_t *hint_fwd_pd)
+ : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_data_md_(desc_.diff_data_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ /* diff_dst and diff_src share the same descriptor */
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1; }
+
+ /* backward eltwise always maps zero diff_dst to zero diff_src */
+ bool is_zero_preserved() const { return true; }
+
+protected:
+ memory_desc_t diff_data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp
new file mode 100644
index 0000000000..3b3e25456d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp
@@ -0,0 +1,75 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+#include "engine.hpp"
+#include "nstl.hpp"
+
+#include "c_types_map.hpp"
+#include "../cpu/cpu_engine.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/* NULL-terminated registry of available engine factories; only the CPU
+ * factory is compiled into this (OIDN) build. */
+engine_factory_t *engine_factories[] = {
+ &cpu::engine_factory,
+ nullptr,
+};
+
+/* linear scan of the registry for a factory matching the requested kind;
+ * returns nullptr when the kind is not supported */
+static inline engine_factory_t *get_engine_factory(engine_kind_t kind) {
+ for (engine_factory_t **ef = engine_factories; *ef; ef++)
+ if ((*ef)->kind() == kind)
+ return *ef;
+ return nullptr;
+}
+
+}
+}
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+/* Public C API: number of engines of the given kind; 0 when the kind is
+ * unknown or unsupported in this build. */
+size_t mkldnn_engine_get_count(engine_kind_t kind) {
+ engine_factory_t *ef = get_engine_factory(kind);
+ return ef != nullptr ? ef->count() : 0;
+}
+
+/* Public C API: creates engine number `index` of the given kind.
+ * Fails with invalid_arguments for a null out-pointer, an unknown kind,
+ * or an out-of-range index; otherwise defers to the factory. */
+status_t mkldnn_engine_create(engine_t **engine,
+ engine_kind_t kind, size_t index) {
+ if (engine == nullptr)
+ return invalid_arguments;
+
+ engine_factory_t *ef = get_engine_factory(kind);
+ if (ef == nullptr || index >= ef->count())
+ return invalid_arguments;
+
+ return ef->engine_create(engine, index);
+}
+
+/* Public C API: reports the kind of an existing engine.
+ * NOTE(review): only `engine` is null-checked; a null `kind` out-pointer
+ * would dereference — confirm whether the API contract guarantees it. */
+status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) {
+ if (engine == nullptr)
+ return invalid_arguments;
+ *kind = engine->kind();
+ return success;
+}
+
+/* Public C API: destroys an engine. Deleting a nullptr is a harmless no-op.
+ * Reference counting is not implemented (see TODO). */
+status_t mkldnn_engine_destroy(engine_t *engine) {
+ /* TODO: engine->dec_ref_count(); */
+ delete engine;
+ return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp
new file mode 100644
index 0000000000..8ac8a29de5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp
@@ -0,0 +1,119 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef ENGINE_HPP
+#define ENGINE_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive.hpp"
+#include "utils.hpp"
+
+/** \brief An abstraction of an execution unit with shared resources
+ *
+ * Responsibilities:
+ * - Provide engine specific memory allocation
+ * - Provide engine specific primitive_desc_t creators
+ */
+struct mkldnn_engine: public mkldnn::impl::c_compatible {
+ mkldnn_engine(mkldnn::impl::engine_kind_t kind)
+ : kind_(kind)
+ {}
+ /* polymorphic base deleted via mkldnn_engine_destroy -> virtual dtor */
+ virtual ~mkldnn_engine() {}
+
+ /** get kind of the current engine */
+ virtual mkldnn::impl::engine_kind_t kind() const { return kind_; }
+
+ /** allocate memory; `handle` lets the caller supply existing storage */
+ virtual mkldnn::impl::status_t memory_create(
+ mkldnn::impl::memory_t **memory,
+ const mkldnn::impl::memory_desc_t *md,
+ void *handle) = 0;
+
+ /** implementation section (typedefs) */
+
+ /* factory-function signatures for the four primitive-desc categories;
+ * concrete engines expose NULL-terminated arrays of these */
+
+ // TODO: remove engine?
+ typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)(
+ mkldnn::impl::reorder_pd_t **reorder_pd,
+ mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_attr_t *attr,
+ mkldnn::impl::engine_t *src_engine,
+ const mkldnn::impl::memory_desc_t *src_md,
+ mkldnn::impl::engine_t *dst_engine,
+ const mkldnn::impl::memory_desc_t *dst_md);
+
+ typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)(
+ mkldnn::impl::concat_pd_t **concat_pd,
+ mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_attr_t *attr,
+ const mkldnn::impl::memory_desc_t *dst_md,
+ int n, int concat_dim,
+ const mkldnn::impl::memory_desc_t *src_mds);
+
+ typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)(
+ mkldnn::impl::sum_pd_t **sum_pd,
+ mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_attr_t *attr,
+ const mkldnn::impl::memory_desc_t *dst_md,
+ int n, const float *scales,
+ const mkldnn::impl::memory_desc_t *src_mds);
+
+ typedef mkldnn::impl::status_t (*primitive_desc_create_f)(
+ mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *,
+ const mkldnn::impl::primitive_attr_t *attr,
+ mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *);
+
+ /* implementation section */
+
+ /** return the list of reorder implementations. engine guarantees to return
+ * a NULL-terminated list */
+ virtual const reorder_primitive_desc_create_f*
+ get_reorder_implementation_list() const = 0;
+
+ /** return the list of concat implementations. engine guarantees to return
+ * a NULL-terminated list */
+ virtual const concat_primitive_desc_create_f*
+ get_concat_implementation_list() const = 0;
+
+ /** return the list of sum implementations. engine guarantees to return
+ * a NULL-terminated list */
+ virtual const sum_primitive_desc_create_f*
+ get_sum_implementation_list() const = 0;
+
+ /** return the list of implementations. engine guarantees to return a
+ * NULL-terminated list */
+ virtual const primitive_desc_create_f* get_implementation_list() const = 0;
+
+protected:
+ mkldnn::impl::engine_kind_t kind_;
+};
+
+namespace mkldnn {
+namespace impl {
+
+/** Abstract factory for engines of a single kind: reports how many engines
+ * it can create, which kind they are, and constructs engine number `index`. */
+struct engine_factory_t: public c_compatible {
+ virtual size_t count() const = 0;
+ virtual engine_kind_t kind() const = 0;
+ virtual status_t engine_create(engine_t **engine, size_t index) const = 0;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
new file mode 100644
index 0000000000..5a9f58cb1e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
@@ -0,0 +1,106 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+/* Validates arguments and fills in an inner product operation descriptor
+ * for any prop_kind. The caller-provided descriptors are routed into the
+ * plain or diff_* fields depending on prop_kind; unused fields are zeroed.
+ * Returns invalid_arguments on any inconsistency, success otherwise. */
+status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) {
+ bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc);
+ if (!args_ok) return invalid_arguments;
+
+ auto id = inner_product_desc_t();
+ id.primitive_kind = primitive_kind::inner_product;
+ id.prop_kind = prop_kind;
+
+ /* start from all-zero descriptors, then fill only the fields that the
+ * chosen prop_kind actually uses */
+ id.diff_src_desc = id.src_desc = zero_md();
+ id.diff_dst_desc = id.dst_desc = zero_md();
+ id.diff_weights_desc = id.weights_desc = zero_md();
+ id.diff_bias_desc = id.bias_desc = zero_md();
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+ /* bias is optional: absent when null or when its format is undef */
+ const bool with_bias =
+ bias_desc && bias_desc->format_kind != format_kind::undef;
+
+ (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc;
+ (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc;
+ (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) =
+ *weights_desc;
+ if (with_bias)
+ (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) =
+ *bias_desc;
+
+ id.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+ weights_desc->data_type, dst_desc->data_type, prop_kind);
+
+ /* shape checks: src is 2D..5D, dst is always 2D (MB x OC), weights match
+ * src rank, bias (if any) is 1D of length OC, and the reduction dims of
+ * src and weights agree */
+ bool consistency = true
+ && memory_desc_wrapper(weights_desc).nelems()
+ && one_of(src_desc->ndims, 2, 3, 4, 5)
+ && dst_desc->ndims == 2
+ && weights_desc->ndims == src_desc->ndims
+ && (with_bias ? bias_desc->ndims == 1 : true)
+ && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
+ && src_desc->dims[0] == dst_desc->dims[0]
+ && array_cmp(&src_desc->dims[1], &weights_desc->dims[1],
+ src_desc->ndims - 1)
+ && dst_desc->dims[1] == weights_desc->dims[0];
+ if (!consistency) return invalid_arguments;
+
+ *ip_desc = id;
+ return success;
+}
+}
+
+/* Public C API: initializes a forward inner product descriptor. Rejects
+ * any non-forward prop_kind before delegating to the common initializer. */
+status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc,
+ prop_kind_t prop_kind, const memory_desc_t *src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_desc) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc,
+ dst_desc);
+}
+
+/* Public C API: initializes a backward-by-data inner product descriptor
+ * (no bias in this direction). */
+status_t mkldnn_inner_product_backward_data_desc_init(
+ inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc)
+{
+ return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc,
+ nullptr, diff_dst_desc);
+}
+
+/* Public C API: initializes a backward-by-weights inner product descriptor;
+ * diff_bias_desc may be null / undef when the op has no bias. */
+status_t mkldnn_inner_product_backward_weights_desc_init(
+ inner_product_desc_t *ip_desc, const memory_desc_t *src_desc,
+ const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc,
+ const memory_desc_t *diff_dst_desc) {
+ return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc,
+ diff_bias_desc, diff_dst_desc);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp
new file mode 100644
index 0000000000..091cf0f5d6
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp
@@ -0,0 +1,56 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "utils.hpp"
+
+#include "inner_product_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+using namespace prop_kind;
+
+/* Propagation-invariant accessors: depending on prop_kind the "same"
+ * logical tensor lives either in the plain or in the diff_* field of the
+ * inner product descriptor; these pick the populated one. */
+memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) {
+ return desc->prop_kind == backward_data
+ ? &desc->diff_src_desc : &desc->src_desc;
+}
+
+memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) {
+ return desc->prop_kind == backward_weights
+ ? &desc->diff_weights_desc : &desc->weights_desc;
+}
+
+memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) {
+ return desc->prop_kind == backward_weights
+ ? &desc->diff_bias_desc : &desc->bias_desc;
+}
+
+memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) {
+ return utils::one_of(desc->prop_kind, forward_inference, forward_training)
+ ? &desc->dst_desc : &desc->diff_dst_desc;
+}
+
+/* const overloads forwarding to the mutable versions above; the const_cast
+ * is safe because the mutable versions never write through the pointer */
+const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); }
+
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp
new file mode 100644
index 0000000000..c426de632c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp
@@ -0,0 +1,321 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef INNER_PRODUCT_PD_HPP
+#define INNER_PRODUCT_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);
+
+struct inner_product_fwd_pd_t;
+
+/** Base primitive descriptor shared by all inner product propagation kinds.
+ * Dimension accessors (MB/IC/OC, spatial and kernel sizes) work for any
+ * prop_kind via the ip_prop_invariant_*_d helpers declared above. */
+struct inner_product_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::inner_product;
+
+ inner_product_pd_t(engine_t *engine,
+ const inner_product_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const inner_product_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ {}
+
+ const inner_product_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ /** answers query::inner_product_d; the rest goes to the base class */
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::inner_product_d:
+ *(const inner_product_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common inner_product aux functions */
+
+ dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; }
+ dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; }
+ dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; }
+
+ /* input spatial sizes, taken from the trailing dims; 1 when absent */
+ dim_t ID() const {
+ return ndims() >= 5
+ ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t IH() const {
+ return ndims() >= 4
+ ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t IW() const {
+ return ndims() >= 3
+ ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1;
+ }
+
+ /* output "spatial" sizes; for inner product dst is 2D, so these index
+ * past dims[1] only when ndims() of src suggests spatial dims exist */
+ dim_t OD() const {
+ return ndims() >= 5
+ ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t OH() const {
+ return ndims() >= 4
+ ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t OW() const {
+ return ndims() >= 3
+ ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1;
+ }
+
+ /* kernel sizes (weights spatial dims); 1 when absent */
+ dim_t KD() const {
+ return ndims() >= 5
+ ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t KH() const {
+ return ndims() >= 4
+ ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t KW() const {
+ return ndims() >= 3
+ ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1;
+ }
+
+ /** flattened reduction size: product of all src dims past the batch dim */
+ dim_t IC_total() const {
+ return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1],
+ ndims() - 1);
+ }
+
+ /** same as IC_total() but over padded dims; requires a blocking layout,
+ * returns -1 otherwise (after asserting in debug builds) */
+ dim_t IC_total_padded() const {
+ auto src_d = desc()->prop_kind == prop_kind::backward_data
+ ? memory_desc_wrapper(diff_src_md())
+ : memory_desc_wrapper(src_md());
+ assert(src_d.is_blocking_desc());
+ if (!src_d.is_blocking_desc()) return -1;
+ return utils::array_product(src_d.padded_dims() + 1, ndims() - 1);
+ }
+
+ int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; }
+
+ bool with_bias() const
+ { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); }
+
+ bool has_zero_dim_memory() const {
+ const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_));
+ const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_));
+ return s_d.has_zero_dim() || d_d.has_zero_dim();
+ }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+protected:
+ inner_product_desc_t desc_;
+ const inner_product_fwd_pd_t *hint_fwd_pd_;
+
+ /* resolves format_kind::any in the given descriptors to the default
+ * plain tags for the current rank (nc/ncw/nchw/ncdhw, oi/oiw/oihw/oidhw,
+ * x for bias); bias_md may be null when the op has no bias */
+ status_t template_set_default_params(memory_desc_t &src_md,
+ memory_desc_t &weights_md, memory_desc_t &dst_md,
+ memory_desc_t *bias_md) {
+ using namespace format_tag;
+ if (src_md.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(src_md,
+ utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)));
+ }
+ if (dst_md.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_md, nc));
+ if (weights_md.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(weights_md,
+ utils::pick(ndims() - 2, oi, oiw, oihw, oidhw)));
+ }
+ if (bias_md && bias_md->format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(*bias_md, x));
+ return status::success;
+ }
+};
+
+/** Primitive descriptor base for forward inner product.
+ * Inputs: src, weights, and optionally bias; output: dst. */
+struct inner_product_fwd_pd_t: public inner_product_pd_t {
+ typedef inner_product_fwd_pd_t base_class;
+ typedef inner_product_fwd_pd_t hint_class;
+
+ inner_product_fwd_pd_t(engine_t *engine,
+ const inner_product_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const inner_product_fwd_pd_t *hint_fwd_pd)
+ : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , weights_md_(desc_.weights_desc)
+ , bias_md_(desc_.bias_desc)
+ , dst_md_(desc_.dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+ return arg_usage_t::input;
+
+ /* bias is an input only when the op actually has one */
+ if (arg == MKLDNN_ARG_BIAS && with_bias())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+ /* index 0: weights; index 1: bias (present only with bias) */
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_md_;
+ if (index == 1 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2 + with_bias(); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t dst_md_;
+
+ /* resolves any `format_kind::any` descriptors to defaults */
+ status_t set_default_params() {
+ return template_set_default_params(src_md_, weights_md_, dst_md_,
+ &bias_md_);
+ }
+};
+
+/** Primitive descriptor base for backward-by-data inner product.
+ * Inputs: weights and diff_dst; output: diff_src. No bias here. */
+struct inner_product_bwd_data_pd_t: public inner_product_pd_t {
+ typedef inner_product_bwd_data_pd_t base_class;
+ typedef inner_product_fwd_pd_t hint_class;
+
+ inner_product_bwd_data_pd_t(engine_t *engine,
+ const inner_product_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const inner_product_fwd_pd_t *hint_fwd_pd)
+ : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_md_(desc_.diff_src_desc)
+ , weights_md_(desc_.weights_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override
+ { return index == 0 ? &weights_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t diff_dst_md_;
+
+ /* resolves any `format_kind::any` descriptors to defaults (no bias) */
+ status_t set_default_params() {
+ return template_set_default_params(diff_src_md_, weights_md_,
+ diff_dst_md_, nullptr);
+ }
+};
+
+/** Primitive descriptor base for backward-by-weights inner product.
+ * Inputs: src and diff_dst; outputs: diff_weights and, with bias,
+ * diff_bias. */
+struct inner_product_bwd_weights_pd_t: public inner_product_pd_t {
+ typedef inner_product_bwd_weights_pd_t base_class;
+ typedef inner_product_fwd_pd_t hint_class;
+
+ inner_product_bwd_weights_pd_t(engine_t *engine,
+ const inner_product_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const inner_product_fwd_pd_t *hint_fwd_pd)
+ : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , diff_weights_md_(desc_.diff_weights_desc)
+ , diff_bias_md_(desc_.diff_bias_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+ return arg_usage_t::output;
+
+ /* diff_bias is an output only when the op actually has a bias */
+ if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ /* index 0: diff_weights; index 1: diff_bias (present only with bias) */
+ virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+ if (index == 0) return &diff_weights_md_;
+ if (index == 1 && with_bias()) return &diff_bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t diff_weights_md_;
+ memory_desc_t diff_bias_md_;
+ memory_desc_t diff_dst_md_;
+
+ /* resolves any `format_kind::any` descriptors to defaults */
+ status_t set_default_params() {
+ return template_set_default_params(src_md_, diff_weights_md_,
+ diff_dst_md_, &diff_bias_md_);
+ }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
new file mode 100644
index 0000000000..fcf18b556f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
@@ -0,0 +1,91 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t lrn_desc_init(lrn_desc_t *lrn_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc,
+ dim_t local_size, float alpha, float beta, float k) {
+ bool args_ok = true
+ && !any_null(lrn_desc, data_desc)
+ && one_of(alg_kind, lrn_within_channel, lrn_across_channels)
+ && one_of(prop_kind, forward_training, forward_inference, backward_data)
+ && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
+ if (!args_ok) return invalid_arguments;
+
+ auto ld = lrn_desc_t();
+ ld.primitive_kind = primitive_kind::lrn;
+ ld.prop_kind = prop_kind;
+ ld.alg_kind = alg_kind;
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+
+ ld.data_desc = *data_desc;
+ if (!is_fwd)
+ ld.diff_data_desc = *diff_data_desc;
+ else
+ ld.diff_data_desc = zero_md();
+ ld.local_size = local_size;
+ ld.lrn_alpha = alpha;
+ ld.lrn_beta = beta;
+ ld.lrn_k = k;
+
+ bool consistency = true
+ && ld.data_desc.ndims == 4;
+ if (ld.prop_kind == backward_data)
+ consistency = consistency
+ && ld.diff_data_desc.ndims == 4
+ && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4);
+ if (!consistency) return invalid_arguments;
+
+ *lrn_desc = ld;
+ return success;
+}
+}
+
+status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *data_desc, dim_t local_size, float alpha,
+ float beta, float k) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
+ local_size, alpha, beta, k);
+}
+
+status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc,
+ alg_kind_t alg_kind, const memory_desc_t *data_desc,
+ const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
+ float beta, float k) {
+ return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
+ diff_data_desc, local_size, alpha, beta, k);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
new file mode 100644
index 0000000000..90886e9656
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
@@ -0,0 +1,170 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef LRN_PD_HPP
+#define LRN_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct lrn_fwd_pd_t;
+
+struct lrn_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::lrn;
+
+ lrn_pd_t(engine_t *engine,
+ const lrn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const lrn_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , data_md_(desc_.data_desc)
+ , ws_md_()
+ {}
+
+ const lrn_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::lrn_d:
+ *(const lrn_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common lrn aux functions */
+
+ dim_t MB() const { return data_desc().dims[0]; }
+ dim_t C() const { return data_desc().dims[1]; }
+ dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+ dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+ dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+ int ndims() const { return data_desc().ndims; }
+
+ bool has_zero_dim_memory() const
+ { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+protected:
+ lrn_desc_t desc_;
+ const lrn_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t data_md_;
+ memory_desc_t ws_md_;
+
+private:
+ const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+struct lrn_fwd_pd_t: public lrn_pd_t {
+ typedef lrn_fwd_pd_t base_class;
+ typedef lrn_fwd_pd_t hint_class;
+
+ lrn_fwd_pd_t(engine_t *engine,
+ const lrn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const lrn_fwd_pd_t *hint_fwd_pd)
+ : lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override
+ { return 1 + (workspace_md() != nullptr); }
+};
+
+struct lrn_bwd_pd_t: public lrn_pd_t {
+ typedef lrn_bwd_pd_t base_class;
+ typedef lrn_fwd_pd_t hint_class;
+
+ lrn_bwd_pd_t(engine_t *engine,
+ const lrn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const lrn_fwd_pd_t *hint_fwd_pd)
+ : lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_data_md_(desc_.diff_data_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::input;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ virtual int n_inputs() const override
+ { return 2 + (workspace_md() != nullptr); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
new file mode 100644
index 0000000000..3fddc0bd45
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
@@ -0,0 +1,280 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MATH_UTILS_HPP
+#define MATH_UTILS_HPP
+
+#include <stdint.h>
+#include <math.h>
+
+#include "utils.hpp"
+#include "nstl.hpp"
+#include "mkldnn_traits.hpp"
+
+#if defined(MKLDNN_X86_64)
+#include "immintrin.h"
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace math {
+
/** rounds @p f to an integer according to the mxcsr register */
inline int mxcsr_round(float f) {
#if defined(MKLDNN_X86_64)
    // cvtss2si honors the MXCSR rounding field (round-to-nearest-even
    // by default)
    return _mm_cvtss_si32(_mm_load_ss(&f));
#else
    // fallback: nearbyintf follows the current FP environment, which is
    // round-to-nearest-even by default — matching the x86 path
    return static_cast<int>(nearbyintf(f));
#endif
}
+
+template <typename data_t, typename acc_t>
+inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
+ typename utils::remove_reference<data_t>::type>::type
+saturate(const acc_t &x) {
+ return (typename utils::remove_reference<data_t>::type)x;
+}
+
+template <typename data_t, typename acc_t>
+inline typename utils::enable_if<nstl::is_integral<data_t>::value,
+ typename utils::remove_reference<data_t>::type>::type
+saturate(const acc_t &x) {
+ acc_t v = x;
+ if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
+ v = (acc_t)nstl::numeric_limits<data_t>::lowest();
+ if (v > (acc_t)nstl::numeric_limits<data_t>::max())
+ v = (acc_t)nstl::numeric_limits<data_t>::max();
+ return (typename utils::remove_reference<data_t>::type)v;
+}
+
+template <typename data_t>
+double saturate(const double &x) {
+ double v = x;
+ if (v < (double)nstl::numeric_limits<data_t>::lowest())
+ v = (double)nstl::numeric_limits<data_t>::lowest();
+ if (v > (double)nstl::numeric_limits<data_t>::max())
+ v = (double)nstl::numeric_limits<data_t>::max();
+ return v;
+}
+
+template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
+ return x <= 127u ? x : 127;
+}
+
+template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
+ return x >= 0 ? x : 0;
+}
+
+template <typename out_t>
+typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
+out_round(float v) { return (out_t)mxcsr_round(v); }
+
+template <typename out_t>
+typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
+out_round(double v) { return (out_t)mxcsr_round((float)v); }
+
+template <typename out_t>
+typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
+out_round(float v) { return v; }
+
+inline int gcd(int a, int b) {
+ a = impl::nstl::abs(a);
+ b = impl::nstl::abs(b);
+ if (a < b) { int x = a; a = b; b = x; }
+
+ if (b == 0) return a;
+
+ int r;
+ while ((r = a % b) != 0) { a = b; b = r; }
+
+ return b;
+}
+
/* True when @p v is a power of two. NOTE: v == 0 also satisfies the
 * bit test and reports true. */
template <typename T>
inline bool is_pow2(const T& v) {
    return (v & (v - 1)) == 0;
}
+
/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */
inline int ilog2q(size_t v) {
    if (v == 0)
        return -1;

    // binary-search the top set bit with power-of-two shift steps
    int p = 0;
    for (int pw = 32; pw > 0; pw >>= 1) {
        if (v >= (1ull << pw)) {
            v >>= pw;
            p += pw;
        }
    }
    return p;
}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U one_m_square(T x) {
+ return (U)(1 - x) * (1 + x);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U x_m_square(T x) {
+ return (U)(1 - x) * x;
+}
+
+/* activation */
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U relu_fwd(T s, A alpha) {
+ return s > 0 ? s : (U)(s * alpha);
+}
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U relu_bwd(T dd, T s, A alpha) {
+ return s > 0 ? dd : (U)(dd * alpha);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U tanh_fwd(T s) {
+ const float e = tanhf((float) s);
+ return (U)e;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U tanh_bwd(T dd, T s) {
+ const float e = tanh_fwd<float>((float) s);
+ return (U)(dd * (1 - e) * (1 + e));
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U elu_fwd(T s, A alpha) {
+ return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
+}
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+ inline U elu_bwd(T dd, T s, A alpha) {
+ return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U square_fwd(T s) {
+ return s * s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U square_bwd(T dd, T s) {
+ return dd * 2 * s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U abs_fwd(T s) {
+ return s > 0 ? s : -s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U abs_bwd(T dd, T s) {
+ return s > 0 ? dd : s < 0 ? -dd : 0;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U sqrt_fwd(T s) {
+ return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U sqrt_bwd(T dd, T s) {
+ return s > 0
+ ? (U)(dd / (2 * ::sqrtf((float)(s))))
+ : 0;
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U linear_fwd(T s, A alpha, A beta) {
+ return (U)(alpha * s + beta);
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U linear_bwd(T dd, T s, A alpha, A beta) {
+ (void) s;
+ (void) beta;
+ return (U)(dd * alpha);
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U bounded_relu_fwd(T s, A alpha) {
+ s = s > 0 ? s : 0;
+ return s > alpha ? (U)(alpha) : s;
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U bounded_relu_bwd(T dd, T s, A alpha) {
+ return dd * (0 < s && s < alpha ? 1 : 0);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U soft_relu_fwd(T s) {
+ float max_logf = 8.872284e+01; //::logf(FLT_MAX)
+ return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U soft_relu_bwd(T dd, T s) {
+ return (U)(dd / (1 + ::expf((float)(-s))));
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U logistic_fwd(T s) {
+ U v = (U)(::expf((float) -s));
+ return 1 / (1 + v);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U logistic_bwd(T dd, T s) {
+ U v = logistic_fwd<T, U>(s);
+ return dd * v * (1 - v);
+}
+
+inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
+ using namespace alg_kind;
+ using namespace utils;
+ const bool preserves_zero = true
+ && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic)
+ && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh));
+ return preserves_zero;
+}
+
+inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
+{
+ if (!bias)
+ return 0.0f;
+
+#define CASE(dt) \
+ case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
+
+ switch (data_type) {
+ CASE(data_type::s8);
+ CASE(data_type::u8);
+ CASE(data_type::s32);
+ CASE(data_type::f32);
+ default: assert(!"unimplemented");
+ }
+ return 0; // never happens (should probably be a NaN)
+#undef CASE
+}
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp
new file mode 100644
index 0000000000..cea849c96e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp
@@ -0,0 +1,238 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::data_type;
+
+namespace {
+bool memory_desc_sanity_check(int ndims,const dims_t dims,
+ data_type_t data_type, format_kind_t format_kind) {
+ if (ndims == 0) return true;
+
+ bool ok = true
+ && dims != nullptr
+ && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS
+ && one_of(data_type, f32, s32, s8, u8)
+ && format_kind != format_kind::undef;
+ if (!ok) return false;
+ for (int d = 0; d < ndims; ++d)
+ if (dims[d] < 0) return false;
+
+ return true;
+}
+
+bool memory_desc_sanity_check(const memory_desc_t *md) {
+ if (md == nullptr) return false;
+ return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
+ format_kind::any);
+}
+}
+
+status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims,
+ const dims_t dims, data_type_t data_type, format_tag_t tag) {
+ if (any_null(memory_desc)) return invalid_arguments;
+ if (ndims == 0 || tag == format_tag::undef) {
+ *memory_desc = types::zero_md();
+ return success;
+ }
+
+ format_kind_t format_kind = types::format_tag_to_kind(tag);
+
+ /* memory_desc != 0 */
+ bool args_ok = !any_null(memory_desc)
+ && memory_desc_sanity_check(ndims, dims, data_type, format_kind);
+ if (!args_ok) return invalid_arguments;
+
+ auto md = memory_desc_t();
+ md.ndims = ndims;
+ array_copy(md.dims, dims, ndims);
+ md.data_type = data_type;
+ array_copy(md.padded_dims, dims, ndims);
+ md.format_kind = format_kind;
+
+ status_t status = success;
+ if (tag == format_tag::undef) {
+ status = invalid_arguments;
+ } else if (tag == format_tag::any) {
+ // nop
+ } else if (format_kind == format_kind::blocked) {
+ status = memory_desc_wrapper::compute_blocking(md, tag);
+ } else {
+ assert(!"unreachable");
+ status = invalid_arguments;
+ }
+
+ if (status == success)
+ *memory_desc = md;
+
+ return status;
+}
+
+status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc,
+ int ndims, const dims_t dims, data_type_t data_type,
+ const dims_t strides) {
+ if (any_null(memory_desc)) return invalid_arguments;
+ if (ndims == 0) {
+ *memory_desc = types::zero_md();
+ return success;
+ }
+
+ /* memory_desc != 0 */
+ bool args_ok = !any_null(memory_desc)
+ && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any);
+ if (!args_ok) return invalid_arguments;
+
+ auto md = memory_desc_t();
+ md.ndims = ndims;
+ array_copy(md.dims, dims, ndims);
+ md.data_type = data_type;
+ array_copy(md.padded_dims, dims, ndims);
+ md.format_kind = format_kind::blocked;
+
+ dims_t default_strides = {0};
+ if (strides == nullptr) {
+ default_strides[md.ndims - 1] = 1;
+ for (int d = md.ndims - 2; d >= 0; --d)
+ default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1];
+ strides = default_strides;
+ } else {
+ /* TODO: add sanity check for the provided strides */
+ }
+
+ array_copy(md.format_desc.blocking.strides, strides, md.ndims);
+
+ *memory_desc = md;
+
+ return status::success;
+}
+
+status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md,
+ const memory_desc_t *parent_md, const dims_t dims,
+ const dims_t offsets) {
+ if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md))
+ return invalid_arguments;
+
+ const memory_desc_wrapper src_d(parent_md);
+
+ for (int d = 0; d < src_d.ndims(); ++d) {
+ if (dims[d] < 0 || offsets[d] < 0
+ || (offsets[d] + dims[d] > src_d.dims()[d]))
+ return invalid_arguments;
+ }
+
+ if (src_d.format_kind() != format_kind::blocked)
+ return unimplemented;
+
+ dims_t blocks;
+ src_d.compute_blocks(blocks);
+
+ memory_desc_t dst_d = *parent_md;
+ auto &dst_d_blk = dst_d.format_desc.blocking;
+
+ /* TODO: put this into memory_desc_wrapper */
+ for (int d = 0; d < src_d.ndims(); ++d) {
+ /* very limited functionality for now */
+ const bool ok = true
+ && offsets[d] % blocks[d] == 0 /* [r1] */
+ && src_d.padded_offsets()[d] == 0
+ && (false
+ || dims[d] % blocks[d] == 0
+ || dims[d] < blocks[d]);
+ if (!ok)
+ return unimplemented;
+
+ const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];
+
+ dst_d.dims[d] = dims[d];
+ dst_d.padded_dims[d] = is_right_border
+ ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d];
+ dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
+ dst_d.offset0 += /* [r1] */
+ offsets[d] / blocks[d] * dst_d_blk.strides[d];
+ }
+
+ *md = dst_d;
+
+ return success;
+}
+
+int mkldnn_memory_desc_equal(const memory_desc_t *lhs,
+ const memory_desc_t *rhs) {
+ if (lhs == rhs) return 1;
+ if (any_null(lhs, rhs)) return 0;
+ return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs);
+}
+
+size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) {
+ if (md == nullptr) return 0;
+ return memory_desc_wrapper(*md).size();
+}
+
+status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md,
+ engine_t *engine, void *handle) {
+ if (any_null(memory, engine)) return invalid_arguments;
+ memory_desc_t z_md = types::zero_md();
+ return engine->memory_create(memory, md ? md : &z_md, handle);
+}
+
+status_t mkldnn_memory_get_memory_desc(const memory_t *memory,
+ const memory_desc_t **md) {
+ if (any_null(memory, md)) return invalid_arguments;
+ *md = memory->md();
+ return success;
+}
+
+status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) {
+ if (any_null(memory, engine)) return invalid_arguments;
+ *engine = memory->engine();
+ return success;
+}
+
+status_t mkldnn_memory_get_data_handle(const memory_t *memory,
+ void **handle) {
+ if (any_null(handle))
+ return invalid_arguments;
+ if (memory == nullptr) {
+ *handle = nullptr;
+ return success;
+ }
+ return memory->get_data_handle(handle);
+}
+
+status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) {
+ if (any_null(memory)) return invalid_arguments;
+ return memory->set_data_handle(handle);
+}
+
+status_t mkldnn_memory_destroy(memory_t *memory) {
+ delete memory;
+ return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp
new file mode 100644
index 0000000000..03dfee01ff
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp
@@ -0,0 +1,63 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_HPP
+#define MEMORY_HPP
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+
/** Abstract base for engine-specific memory objects: binds an engine and a
 * (copied) memory descriptor to a data handle managed by the concrete
 * implementation. */
struct mkldnn_memory: public mkldnn::impl::c_compatible {
    mkldnn_memory(mkldnn::impl::engine_t *engine,
            const mkldnn::impl::memory_desc_t *md)
        : engine_(engine), md_(*md) {}
    virtual ~mkldnn_memory() {}

    /** allocates/initializes memory */
    virtual mkldnn::impl::status_t init() = 0;

    /** returns memory's engine */
    mkldnn::impl::engine_t *engine() const { return engine_; }
    /** returns memory's description */
    const mkldnn::impl::memory_desc_t *md() const { return &md_; }

    /** returns data handle */
    virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0;

    /** sets data handle */
    virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0;

    /** zeros padding (default: nothing to do, report success) */
    virtual mkldnn::impl::status_t zero_pad() const
    { return mkldnn::impl::status::success; }

protected:
    mkldnn::impl::engine_t *engine_;
    const mkldnn::impl::memory_desc_t md_; // copied at construction, immutable

private:
    /* memory objects are neither copyable nor movable */
    mkldnn_memory() = delete;
    mkldnn_memory(const mkldnn_memory &) = delete;
    mkldnn_memory(mkldnn_memory &&) = delete;
    mkldnn_memory &operator=(const mkldnn_memory &) = delete;
    mkldnn_memory &operator=(mkldnn_memory &&) = delete;
};
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp
new file mode 100644
index 0000000000..8a99be33f3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp
@@ -0,0 +1,212 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include <initializer_list>
+
+#include "c_types_map.hpp"
+#include "memory_desc_wrapper.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
/* Fills the blocking description of @p md given the dimension permutation
 * @p perm (outermost to innermost), the inner block sizes @p inner_blks,
 * and, for each inner block, the dimension it refines (@p inner_idxs).
 * Strides are computed for a dense layout over the padded dims. */
status_t fill_blocked(memory_desc_t &md,
        std::initializer_list<int> perm,
        std::initializer_list<int> inner_blks,
        std::initializer_list<int> inner_idxs) {
    const bool ok = true
        && perm.size() == (size_t)md.ndims
        && inner_blks.size() == inner_idxs.size();
    if (!ok) return status::invalid_arguments;

    md.offset0 = 0;

    blocking_desc_t &blk = md.format_desc.blocking;

    dim_t block_size = 1;
    dims_t blocks = {0};
    utils::array_set(blocks, 1, md.ndims);

    blk.inner_nblks = (int)inner_blks.size();

    // record which dimension each inner block belongs to
    int iblk = 0;
    for (const auto &b: inner_idxs)
        blk.inner_idxs[iblk++] = b;

    // accumulate the total block size and the combined per-dimension
    // block factor (a dimension may be refined by several inner blocks)
    iblk = 0;
    for (const auto &b: inner_blks) {
        int dim = blk.inner_idxs[iblk];
        block_size *= b;
        blocks[dim] *= b;
        blk.inner_blks[iblk++] = b;
    }

    // pad every dimension up to a multiple of its combined block factor
    utils::array_set(md.padded_offsets, 0, md.ndims);
    for (int d = 0; d < md.ndims; ++d)
        md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);

    // start from the total dense size (in elements), then divide out each
    // dimension's block-count walking the permutation outermost-first
    dim_t stride = block_size;
    // if only we use C++14, the initializer_list would have rbegin()/rend()...
    for (int d = 0; d < md.ndims; ++d)
        stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d];

    for (const auto &d: perm) {
        if (md.padded_dims[d] == 0) {
            // zero-sized dimension: stride value is irrelevant, use 1
            blk.strides[d] = 1;
            continue;
        }
        stride /= md.padded_dims[d] / blocks[d];
        blk.strides[d] = stride;
    }

    // after dividing out every dimension only the block size remains
    assert(stride == block_size);

    return status::success;
}
+
/* Translates a format tag into a concrete blocking description for
 * @p memory_desc via fill_blocked(). In the tag mnemonics, letters name
 * the dimensions outermost-to-innermost (upper-case marks a dimension
 * that also has inner blocks) and digit groups give inner block sizes. */
status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc,
        format_tag_t tag)
{
    using namespace format_tag;

    if (memory_desc.ndims == 0) return status::invalid_arguments;

# define C(tag, ... /* perm, inner_blks, inner_idxs */) \
    case tag: return fill_blocked(memory_desc, __VA_ARGS__)

    switch (tag) {
    /* plain formats: pure permutations, no inner blocks */
    C(a, {0}, {}, {});
    C(ab, {0, 1}, {}, {});
    C(abc, {0, 1, 2}, {}, {});
    C(abcd, {0, 1, 2, 3}, {}, {});
    C(abcde, {0, 1, 2, 3, 4}, {}, {});
    C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {});
    C(abdec, {0, 1, 3, 4, 2}, {}, {});
    C(acb, {0, 2, 1}, {}, {});
    C(acbde, {0, 2, 1, 3, 4}, {}, {});
    C(acdb, {0, 2, 3, 1}, {}, {});
    C(acdeb, {0, 2, 3, 4, 1}, {}, {});
    C(ba, {1, 0}, {}, {});
    C(bac, {1, 0, 2}, {}, {});
    C(bacd, {1, 0, 2, 3}, {}, {});
    C(bcda, {1, 2, 3, 0}, {}, {});
    C(cba, {2, 1, 0}, {}, {});
    C(cdba, {2, 3, 1, 0}, {}, {});
    C(cdeba, {2, 3, 4, 1, 0}, {}, {});
    C(decab, {3, 4, 2, 0, 1}, {}, {});

    /* blocked formats with 4-wide inner blocks */
    C(Abc4a, {0, 1, 2}, {4}, {0});
    C(aBc4b, {0, 1, 2}, {4}, {1});
    C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1});
    C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0});
    C(Abcd4a, {0, 1, 2, 3}, {4}, {0});
    C(aBcd4b, {0, 1, 2, 3}, {4}, {1});
    C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0});
    C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2});
    C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
    C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0});
    C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1});
    C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0});
    C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
    C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1});
    C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1});
    C(aBdc4b, {0, 1, 3, 2}, {4}, {1});
    C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1});
    C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1});
    C(Acb4a, {0, 2, 1}, {4}, {0});
    C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
    C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});

    /* blocked formats with 8- and 16-wide inner blocks */
    C(Abc16a, {0, 1, 2}, {16}, {0});
    C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
    C(aBc16b, {0, 1, 2}, {16}, {1});
    C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0});
    C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0});
    C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1});
    C(aBc8b, {0, 1, 2}, {8}, {1});
    C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1});
    C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0});
    C(Abcd16a, {0, 1, 2, 3}, {16}, {0});
    C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1});
    C(aBcd16b, {0, 1, 2, 3}, {16}, {1});
    C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0});
    C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2});
    C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1});
    C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1});
    C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0});
    C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1});
    C(aBcd8b, {0, 1, 2, 3}, {8}, {1});
    C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1});
    C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1});
    C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0});
    C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2});
    C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2});
    C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1});
    C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0});
    C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1});
    C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1});
    C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0});
    C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2});
    C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1});
    C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2});
    C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2});
    C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2});
    C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0});
    C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1});
    C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1});
    C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1});
    C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1});
    C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0});
    C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2});
    C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2});
    C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1});
    C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1});
    C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2});
    C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1});
    C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2});
    C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2});
    C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1});
    C(aBdc16b, {0, 1, 3, 2}, {16}, {1});
    C(aBdc8b, {0, 1, 3, 2}, {8}, {1});
    C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1});
    C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1});
    C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1});
    C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1});
    C(Acb16a, {0, 2, 1}, {16}, {0});
    C(Acb8a, {0, 2, 1}, {8}, {0});
    C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2});
    C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2});
    C(Acdb16a, {0, 2, 3, 1}, {16}, {0});
    C(Acdb8a, {0, 2, 3, 1}, {8}, {0});
    C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0});
    C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0});
    C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1});
    C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1});
    default: break;
    }

#undef C

    // unknown/unsupported tag
    return status::invalid_arguments;
}
+
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp
new file mode 100644
index 0000000000..1758f9078a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp
@@ -0,0 +1,400 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_DESC_WRAPPER_HPP
+#define MEMORY_DESC_WRAPPER_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/** thin wrapper class over \struct memory_desc_t which allows easy
+ * manipulations with underlying C structure, which is taken by reference */
+struct memory_desc_wrapper: public c_compatible {
+ const memory_desc_t *md_;
+
+ /** constructor which takes a reference to a constant underlying C memory
+ * descriptor \param md */
+ memory_desc_wrapper(const memory_desc_t *md): md_(md) {}
+ memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {}
+
+ /* implementing attributes */
+ int ndims() const { return md_->ndims; }
+ const dims_t &dims() const { return md_->dims; }
+ data_type_t data_type() const { return md_->data_type; }
+
+ const dims_t &padded_dims() const { return md_->padded_dims; }
+ const dims_t &padded_offsets() const { return md_->padded_offsets; }
+ dim_t offset0() const { return md_->offset0; }
+
+ format_kind_t format_kind() const { return md_->format_kind; }
+
+ bool is_blocking_desc() const
+ { return format_kind() == format_kind::blocked; }
+ bool is_wino_desc() const
+ { return format_kind() == format_kind::wino; }
+ bool is_rnn_packed_desc() const
+ { return format_kind() == format_kind::rnn_packed; }
+
+ const blocking_desc_t &blocking_desc() const {
+ assert(is_blocking_desc());
+ return md_->format_desc.blocking;
+ }
+ const wino_desc_t &wino_desc() const {
+ assert(is_wino_desc());
+ return md_->format_desc.wino_desc;
+ }
+ const rnn_packed_desc_t &rnn_packed_desc() const {
+ assert(is_rnn_packed_desc());
+ return md_->format_desc.rnn_packed_desc;
+ }
+
+ const memory_extra_desc_t &extra() const { return md_->extra; }
+
+ /* some useful functions */
+
+ /** returns the number of elements including padding if \param with_padding
+ * is true, and the number of data elements otherwise */
+ dim_t nelems(bool with_padding = false) const {
+ if (is_zero()) return 0;
+ return utils::array_product(
+ with_padding ? padded_dims() : dims(), ndims());
+ }
+
+ /** returns true if memory descriptor is zero */
+ bool is_zero() const { return ndims() == 0; }
+
+ /** returns true if memory descriptor contains zero as one of its dim */
+ bool has_zero_dim() const { return nelems() == 0; }
+
+ /** return the size of data type (a shortcut) */
+ size_t data_type_size() const
+ { return types::data_type_size(data_type()); }
+
+ /** return the size of data type of additional buffer */
+ size_t additional_buffer_data_size() const {
+ if (extra().flags & memory_extra_flags::compensation_conv_s8s8)
+ return sizeof(int32_t);
+ return 0;
+ }
+
+ /** return true if memory format has additional buffer */
+ bool is_additional_buffer() const {
+ return (extra().flags & memory_extra_flags::compensation_conv_s8s8);
+ }
+
+ /** returns the size of additional buffer */
+ size_t additional_buffer_size() const {
+ if (extra().flags & memory_extra_flags::compensation_conv_s8s8) {
+ int cmask = extra().compensation_mask;
+ assert(cmask == 1 || cmask == 3);
+ dim_t prod = 1;
+ for (int d = 0; d < ndims(); ++d)
+ if (cmask & (1<<d)) prod *= padded_dims()[d];
+ return prod * additional_buffer_data_size();
+ }
+
+ return 0;
+ }
+
+ /** returns the size required to store described memory
+ * note: if offset0 != 0 returns 0 (need to specify the behavior) */
+ size_t size() const {
+ if (is_zero() || has_zero_dim() || format_kind() == format_kind::any)
+ return 0;
+
+ if (format_kind() == format_kind::wino) {
+ return wino_desc().size;
+ } else if (format_kind() == format_kind::rnn_packed) {
+ return rnn_packed_desc().size;
+ } else {
+ if (offset0() != 0) return 0;
+
+ dims_t blocks = {0};
+ compute_blocks(blocks);
+
+ const auto &bd = blocking_desc();
+
+ size_t max_size = 0;
+ for (int d = 0; d < ndims(); ++d)
+ max_size = nstl::max<size_t>(max_size,
+ padded_dims()[d] / blocks[d] * bd.strides[d]);
+
+ if (max_size == 1 && bd.inner_nblks != 0) {
+ max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
+ }
+
+ return max_size * data_type_size() + additional_buffer_size();
+ }
+ }
+
+ /** returns true if data is dense in memory */
+ bool is_dense(bool with_padding = false) const {
+ if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
+ return false;
+ return nelems(with_padding) * data_type_size() == size();
+ }
+
+ /** returns true if memory desc is fully defined */
+ bool is_defined() const { return format_kind() != format_kind::any; }
+
+ /** returns true if the only (potentially) padded dim is \param dim */
+ bool only_padded_dim(int dim) const {
+ for (int d = 0; d < ndims(); ++d)
+ if (d != dim && dims()[d] != padded_dims()[d])
+ return false;
+ return true;
+ }
+
+ /** returns true if memory desc has blocked layout and block dims are 1s */
+ bool is_plain() const {
+ if (!is_blocking_desc()) return false;
+ return blocking_desc().inner_nblks == 0;
+ }
+
+ /** returns overall block sizes */
+ void compute_blocks(dims_t blocks) const {
+ if (!is_blocking_desc()) {
+ utils::array_set(blocks, 0, ndims());
+ return;
+ }
+
+ utils::array_set(blocks, 1, ndims());
+
+ const auto &bd = blocking_desc();
+ for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
+ blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
+ }
+
+ /* comparison section */
+
+ bool operator==(const memory_desc_wrapper &rhs) const
+ { return *this->md_ == *rhs.md_; }
+ bool operator!=(const memory_desc_wrapper &rhs) const
+ { return !operator==(rhs); }
+ bool operator==(const memory_desc_t &rhs) const
+ { return operator==(memory_desc_wrapper(rhs)); }
+ bool operator!=(const memory_desc_t &rhs) const
+ { return !operator==(rhs); }
+
+ /** returns true if data (w/o padding if with_padding == false and w/
+ * padding otherwise) have the same physical structure, i.e. dimensions,
+ * strides, and blocked structure. Depending on with_data_type flag
+ * data_type is taken or not taken into account. dim_start allows to check
+ * similarity for the logical part of data [dim_start .. ndims()].
+ * CAUTION: format kind any and undef are not similar to whatever, hence the
+ * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
+ /* TODO: revise */
+ bool similar_to(const memory_desc_wrapper &rhs,
+ bool with_padding = true, bool with_data_type = true,
+ int dim_start = 0) const;
+
+ /** returns true if one memory can be reordered to another */
+ bool consistent_with(const memory_desc_wrapper &rhs) const;
+
+ /** returns true if the memory desc corresponds to the given format tag and
+ * strides.
+ * @sa memory_desc_matches_tag */
+ bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
+ return memory_desc_matches_tag(*md_, tag, strides);
+ }
+
+ /** returns matching tag (or undef if match is not found)
+ * XXX: This is a workaround that eventually should go away! */
+ template <typename... Tags>
+ format_tag_t matches_one_of_tag(Tags ...tags) const {
+ for (const auto tag: {tags...}) {
+ if (memory_desc_matches_tag(*md_, tag))
+ return tag;
+ }
+ return format_tag::undef;
+ }
+
+ /* offset section */
+
+ /** returns physical offset by logical one. logical offset is represented by
+ * an array \param pos. if \param is_pos_padded is true \param pos
+ * represents the position in already padded area */
+ dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
+ assert(is_blocking_desc());
+ const blocking_desc_t &blk = blocking_desc();
+
+ dims_t pos_copy = {0};
+ for (int d = 0; d < ndims(); ++d)
+ pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);
+
+ dim_t phys_offset = offset0();
+
+ if (blk.inner_nblks > 0) {
+ dim_t blk_stride = 1;
+ for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
+ const int d = blk.inner_idxs[iblk];
+ const dim_t p = pos_copy[d] % blk.inner_blks[iblk];
+
+ phys_offset += p * blk_stride;
+
+ pos_copy[d] /= blk.inner_blks[iblk];
+
+ blk_stride *= blk.inner_blks[iblk];
+ }
+ }
+
+ for (int d = 0; d < ndims(); ++d) {
+ const dim_t p = pos_copy[d];
+ phys_offset += p * blk.strides[d];
+ }
+
+ return phys_offset;
+ }
+
+ /** returns physical offset by logical one. logical offset is represented by
+ * a scalar \param l_offset. if \param is_pos_padded is true, \param
+ * l_offset represents logical offset in already padded area */
+ dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
+ assert(is_blocking_desc());
+ dims_t pos;
+ for (int rd = 0; rd < ndims(); ++rd) {
+ const int d = ndims() - 1 - rd;
+ const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d];
+ pos[d] = l_offset % cur_dim;
+ l_offset /= cur_dim;
+ }
+ return off_v(pos, is_pos_padded);
+ }
+
+ /** returns physical offset by logical one. logical offset is represented by
+ * a tuple of indices (\param xn, ..., \param x1, \param x0) */
+ template<typename... Args>
+ dim_t off(Args... args) const {
+ assert(sizeof...(args) == ndims());
+ dims_t pos = { args... };
+ return off_v(pos, false);
+ }
+
+ /** returns physical offset by logical one. logical offset is represented by
+ * a tuple of indices (\param xn, ..., \param x1, \param x0) in already
+ * padded area */
+ template<typename... Args>
+ dim_t off_padding(Args... args) const {
+ assert(sizeof...(args) == ndims());
+ dims_t pos = { args... };
+ return off_v(pos, true);
+ }
+
+ /** returns physical offset by logical one. Logical offset is represented by
+ * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a
+ * user responsibility to adjust the result to get offset within blocks */
+ template<typename ...Args>
+ dim_t blk_off(Args... args) const {
+ return _blk_off<sizeof...(args), Args...>(args...);
+ }
+
+ template<bool skip_first, typename T, typename ...Args>
+ dim_t blk_off(T xn, Args... args) const {
+ return skip_first
+ ? blk_off<Args...>(args...)
+ : blk_off<T, Args...>(xn, args...);
+ }
+
+ /* static functions section */
+ /* TODO: replace with non-static, once md_ becomes non-const ref */
+
+ static status_t compute_blocking(memory_desc_t &memory_desc,
+ format_tag_t tag);
+
+private:
+ /* TODO: put logical_offset in utils */
+ template<typename T>
+ dim_t logical_offset(T x0) const { return x0; }
+
+ template<typename T, typename... Args>
+ dim_t logical_offset(T xn, Args... args) const {
+ const size_t n_args = sizeof...(args);
+ return xn * utils::array_product<n_args>(
+ &dims()[ndims() - n_args]) + logical_offset(args...);
+ }
+
+ template<int ORIG_LEN, typename ...Void>
+ dim_t _blk_off() const { return offset0(); }
+
+ template<int ORIG_LEN, typename T, typename ...Args>
+ dim_t _blk_off(T xc, Args ...args) const {
+ assert(is_blocking_desc());
+ constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
+ return xc * blocking_desc().strides[dc]
+ + _blk_off<ORIG_LEN, Args...>(args...);
+ }
+};
+
+inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
+ bool with_padding, bool with_data_type, int dim_start) const {
+ using namespace utils;
+
+ if (one_of(format_kind(), format_kind::undef, format_kind::any))
+ return false;
+ if (is_wino_desc() || is_rnn_packed_desc())
+ return false;
+
+ const int ds = dim_start;
+ const auto &blk = blocking_desc();
+ const auto &r_blk = rhs.blocking_desc();
+
+ return ndims() == rhs.ndims()
+ && dim_start <= ndims() /* guard */
+ && format_kind() == rhs.format_kind()
+ && IMPLICATION(with_data_type, data_type() == rhs.data_type())
+ && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
+ && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
+ && blk.inner_nblks == r_blk.inner_nblks
+ && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
+ && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
+ && IMPLICATION(with_padding, true
+ && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds,
+ ndims() - ds)
+ && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds,
+ ndims() - ds));
+}
+
+inline bool memory_desc_wrapper::consistent_with(
+ const memory_desc_wrapper &rhs) const {
+ if (ndims() == rhs.ndims()) {
+ for (int d = 0; d < ndims(); ++d) {
+ if (dims()[d] != rhs.dims()[d]) return false;
+ }
+ return true;
+ } else {
+ /* TODO: revise.
+ * is the following possible?
+ * [1, a, b] <--reorder--> [a, b]
+ * [a, 1, b] <--reorder--> [a, b]
+ * not, at least for now */
+ return false;
+ }
+}
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp
new file mode 100644
index 0000000000..ec077b308c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp
@@ -0,0 +1,295 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_TRACKING_HPP
+#define MEMORY_TRACKING_HPP
+
+#include <assert.h>
+#include <unordered_map>
+
+#include "nstl.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace memory_tracking {
+
+/* Memory tracking capabilities
+ *
+ * The main purpose of this header file is to provide uniform way to register
+ * required memory for a scratchpad at a primitive descriptor creation time
+ * and then easily access it having only the base address of the scratchpad.
+ *
+ * Primitives might contain multiple disjoint parts that require temporary
+ * buffers (known as scratchpad) during their execution. A primitive descriptor
+ * should summarize all the needs into one single number -- the buffer size
+ * that would be requested from a user. At execution time, the corresponding
+ * primitive will receive a base pointer to a scratchpad. It then needs to
+ * provide each part of the algorithm the corresponding piece of memory. Three main
+ * challenges here are:
+ * 1. Track correct offset (from the base scratchpad address) for each piece
+ * 2. Algorithm might require that different memory pieces be aligned, so
+ * the scratchpad size is no longer just a sum of the sizes of the corresponding
+ * subparts.
+ * 3. While a primitive is responsible for its scratchpad, the implementation
+ * might use some other basic blocks (e.g. cpu_reducer) that also require
+ * scratchpad memory. So there should be a simple way of passing the
+ * information back and forth between the main algorithm (a primitive) and
+ * auxiliary stuff that lives completely separately from it (e.g. reducer).
+ *
+ * To address these challenges this header file provides 3 structures:
+ * 1. registry_t -- the class that stores the information about requested
+ * memory. The information includes required size and desired
+ * alignment for each piece. This class is also responsible
+ * for computing the right offset to a given piece using the
+ * base pointer.
+ * This class is basically a ledger with all entries.
+ * Lives in primitive descriptors.
+ *
+ * 2. registrar_t -- the interface to a registry_t to book memory. Used at
+ * primitive descriptor creation time only. Contains a
+ * reference to the corresponding *mutable* registry.
+ * Always modifiable.
+ * Allows chaining (using prefixes).
+ *
+ * 3. grantor_t -- the interface to a registry_t to access memory. Used at
+ * primitive execution time only. Contains a reference to
+ * the corresponding *constant* registry and base pointer.
+ * Always constant.
+ * Allows chaining (using prefixes).
+ *
+ * Both registrar_t and grantor_t allow chaining with extra prefix provided.
+ * The feature is useful when a primitive offloads a part of computations to
+ * some other primitives which require their own scratchpad space
+ * (e.g. reducer). Prefixes are used to avoid key collision in cases when
+ * multiple sub-primitives (e.g. multiple reducers) are used.
+ *
+ * A short example below demonstrates how to use aforementioned classes. In it
+ * the main primitive is convolution that uses scratchpad for keeping padded
+ * bias. It also needs a reducer, that needs its own space as well.
+ *
+ * ``` c++
+ * struct reducer_t {
+ * static void init(registrar_t &scratchpad) {
+ * // preserve space for the reduction (one page aligned)
+ * scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096);
+ * }
+ *
+ * void exec(const grantor_t &scratchpad) {
+ * // get the pointer to preserved space. scratchpad came from
+ * // upper primitive (convolution in this example)
+ * auto space = scratchpad.get<float>(key_reducer_space);
+ *
+ * space[:] += ...;
+ * }
+ * };
+ *
+ * struct conv_t {
+ * struct pd_t {
+ * void init() {
+ * registrar_t scratchpad(scratchpad_registry_);
+ *
+ * // preserve a space for padded bias (using default alignment)
+ * scratchpad.book(key_conv_padded_bias, 128);
+ *
+ * // create a proxy registrar for the reducer. All entries made
+ * // by reducer would live in convolution's registry, but would
+ * // have their own `prefix`, so no interference with conv's
+ * // buffers.
+ * registrar_t reducer_scratchpad(scratchpad, prefix_reducer);
+ *
+ * reducer_t::init(reducer_scratchpad);
+ * }
+ *
+ * registry_t scratchpad_registry_;
+ * }
+ *
+ * void exec() {
+ * // get the base pointer to a scratchpad memory from a user
+ * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD);
+ *
+ * // create a grantor to the scratchpad (and provide the base
+ * // pointer).
+ * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr);
+ *
+ * // access the padded_bias (need only key name and the grantor)
+ * auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+ *
+ * // to give the `right` grantor to reducer we need to add the
+ * // corresponding prefix, so that reducer would be able to access
+ * // its keys. The call is very similar to the one in pd_t::init
+ * // with only difference in types: grantor_t vs registrar_t.
+ * grantor_t reducer_scratchpad(scratchpad, prefix_reducer);
+ * reducer->exec(reducer_scratchpad);
+ * }
+ * };
+ * ```
+ */
+
+
+/* namespace with common keys and prefixes */
+namespace names {
+enum {
+ key_none = 0,
+ key_bnorm_tmp_mean,
+ key_bnorm_tmp_var,
+ key_bnorm_tmp_diff_ss,
+ key_bnorm_tmp_stats,
+ key_bnorm_reduction,
+ key_concat_iptrs,
+ key_concat_istrides,
+ key_concat_nelems,
+ key_concat_optrs,
+ key_conv_adjusted_scales,
+ key_conv_bia_reduction,
+ key_conv_gemm_col,
+ key_conv_gemm_imtr,
+ key_conv_int_dat_in_acc_dt,
+ key_conv_padded_bias,
+ key_conv_rtus_space,
+ key_conv_tr_diff_dst,
+ key_conv_tr_diff_dst_bctx,
+ key_conv_tr_src,
+ key_conv_tr_src_bctx,
+ key_conv_wei_reduction,
+ key_conv_wei_bia_reduction,
+ key_conv_wei_bia_reduction_bctx,
+ key_iprod_int_dat_in_acc_dt,
+ key_reducer_space,
+ key_reducer_space_bctx,
+ key_reorder_wino_plain,
+ key_reorder_wino_transform_space,
+ key_reorder_rnn_weights_quantization,
+ key_reorder_rnn_weights_reduction,
+ key_rnn_space,
+ key_rnn_ptrs_bia,
+ key_rnn_ptrs_wei_layer,
+ key_rnn_ptrs_wei_iter,
+ key_softmax_reduction,
+ key_wino_U,
+ key_wino_V,
+ key_wino_M,
+ key_barrier,
+};
+
+enum {
+ prefix_none = 0,
+ prefix_reducer_bia,
+ prefix_reducer_wei,
+};
+}
+
+// level 0: 00 00 00 xxx
+// level 1: 00 00 aa xxx
+// level 2: 00 aa bb xxx
+// level 3: aa bb cc xxx
+// max # of levels: 3 + 1 (base_level)
+// here:
+// xxx : [1 .. MAX_KEY) : key
+// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3
+
+using key_t = uint32_t;
+enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), };
+
+/// generates global key based on a prefix and a local key
+inline key_t make_key(key_t prefix, key_t key) { return prefix + key; }
+
+/// generates global prefix based on the global parent and the local ones
+inline key_t make_prefix(key_t parent_prefix, key_t prefix)
+{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; }
+
+struct registrar_t;
+struct grantor_t;
+
+struct registry_t {
+ void book(const key_t &key, size_t size, size_t alignment) {
+ if (size == 0) return;
+ assert(offset_map_.count(key) == 0);
+
+ size = utils::rnd_up(size, minimal_alignment);
+ alignment = nstl::max<size_t>(alignment, minimal_alignment);
+ offset_map_[key] = entry_t{size_, size, alignment};
+
+ size_ += size + alignment - minimal_alignment;
+ }
+
+ void *get(const key_t &key, void *base_ptr) const {
+ if (base_ptr == nullptr) { assert(size() == 0); return nullptr; }
+ if (offset_map_.count(key) != 1) return nullptr;
+
+ const auto &e = offset_map_.at(key);
+ base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment);
+ char *ptr = (char *)base_ptr + e.offset;
+ return utils::align_ptr<void>(ptr, e.alignment);
+ }
+
+ size_t size() const
+ { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; }
+
+ registrar_t registrar();
+ grantor_t grantor(void *base_ptr) const;
+
+protected:
+ enum { minimal_alignment = 64 };
+ struct entry_t { size_t offset, size, alignment; };
+
+ std::unordered_map<key_t, entry_t> offset_map_;
+ size_t size_ = 0;
+};
+
+struct registrar_t {
+ enum { default_alignment = 64 };
+
+ registrar_t(registry_t &registry): registry_(registry), prefix_(0) {}
+ registrar_t(registrar_t &parent, const key_t &prefix)
+ : registry_(parent.registry_)
+ , prefix_(make_prefix(parent.prefix_, prefix)) {}
+
+ void book(const key_t &key, size_t size,
+ size_t alignment = default_alignment)
+ { registry_.book(make_key(prefix_, key), size, alignment); }
+
+protected:
+ registry_t &registry_;
+ const key_t prefix_;
+};
+
+struct grantor_t {
+ grantor_t(const registry_t &registry, void *base_ptr)
+ : registry_(registry), prefix_(0), base_ptr_(base_ptr) {}
+ grantor_t(const grantor_t &parent, const key_t &prefix)
+ : registry_(parent.registry_)
+ , prefix_(make_prefix(parent.prefix_, prefix))
+ , base_ptr_(parent.base_ptr_) {}
+
+ template <typename T = void> T *get(const key_t &key) const
+ { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); }
+
+protected:
+ const registry_t &registry_;
+ const key_t prefix_;
+ void *base_ptr_;
+};
+
+inline registrar_t registry_t::registrar() { return registrar_t(*this); }
+inline grantor_t registry_t::grantor(void *base_ptr) const
+{ return grantor_t(*this, base_ptr); }
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
new file mode 100644
index 0000000000..2ef4a8fddc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
@@ -0,0 +1,131 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <stdio.h>
+#include <cinttypes>
+
+#include "mkldnn_debug.h"
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#define DPRINT(...) do { \
+ int l = snprintf(str + written_len, str_len, __VA_ARGS__); \
+ if (l < 0) return l; \
+ if ((size_t)l >= str_len) return -1; \
+ written_len += l; str_len -= l; \
+} while(0)
+
+int mkldnn_md2fmt_str(char *str, size_t str_len,
+ const mkldnn_memory_desc_t *mdesc) {
+ using namespace mkldnn::impl;
+
+ if (str == nullptr || str_len <= 1u)
+ return -1;
+
+ int written_len = 0;
+
+ if (mdesc == nullptr) {
+ DPRINT("%s::%s::",
+ mkldnn_dt2str(data_type::undef),
+ mkldnn_fmt_kind2str(format_kind::undef));
+ return written_len;
+ }
+
+ memory_desc_wrapper md(mdesc);
+
+ DPRINT("%s:", mkldnn_dt2str(md.data_type()));
+
+ bool padded_dims = false, padded_offsets = false;
+ for (int d = 0; d < md.ndims(); ++d) {
+ if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true;
+ if (md.padded_offsets()[d] != 0) padded_offsets = true;
+ }
+ bool offset0 = md.offset0();
+ DPRINT("%s%s%s:",
+ padded_dims ? "p" : "",
+ padded_offsets ? "o" : "",
+ offset0 ? "0" : "");
+
+ DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind()));
+
+ if (!md.is_blocking_desc()) {
+ /* TODO: extend */
+ DPRINT("%s:", "");
+ } else {
+ const auto &blk = md.blocking_desc();
+
+ dims_t blocks;
+ md.compute_blocks(blocks);
+
+ char dim_chars[MKLDNN_MAX_NDIMS + 1];
+
+ bool plain = true;
+ for (int d = 0; d < md.ndims(); ++d) {
+ dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d;
+ if (blocks[d] != 1) plain = false;
+ }
+
+ dims_t strides;
+ utils::array_copy(strides, blk.strides, md.ndims());
+ utils::simultaneous_sort(strides, dim_chars, md.ndims(),
+ [](dim_t a, dim_t b) { return b - a; });
+
+ dim_chars[md.ndims()] = '\0';
+ DPRINT("%s", dim_chars);
+
+ if (!plain) {
+ for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
+ DPRINT("%d%c", (int)blk.inner_blks[iblk],
+ 'a' + (char)blk.inner_idxs[iblk]);
+ }
+ }
+
+ DPRINT("%s", ":");
+ }
+
+ DPRINT("f%lx", (long)md.extra().flags);
+
+ return written_len;
+}
+
+int mkldnn_md2dim_str(char *str, size_t str_len,
+ const mkldnn_memory_desc_t *mdesc) {
+ using namespace mkldnn::impl;
+
+ if (str == nullptr || str_len <= 1)
+ return -1;
+
+ int written_len = 0;
+
+ if (mdesc == nullptr || mdesc->ndims == 0) {
+ DPRINT("%s", "");
+ return written_len;
+ }
+
+ memory_desc_wrapper md(mdesc);
+
+ for (int d = 0; d < md.ndims() - 1; ++d)
+ DPRINT("%" PRId64 "x", md.dims()[d]);
+ DPRINT("%" PRId64, md.dims()[md.ndims() - 1]);
+
+ return written_len;
+}
+
+#undef DPRINT
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp
new file mode 100644
index 0000000000..16a8f7ea5e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp
@@ -0,0 +1,365 @@
+/*******************************************************************************
+* Copyright 2018-2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/* DO NOT EDIT, AUTO-GENERATED */
+
+#include <assert.h>
+
+#include "mkldnn_debug.h"
+#include "mkldnn_types.h"
+
+/* Maps an mkldnn_status_t to its human-readable name; asserts in debug
+ * builds and returns a sentinel string for unrecognized values. */
+const char *mkldnn_status2str(mkldnn_status_t v) {
+    if (v == mkldnn_success) return "success";
+    if (v == mkldnn_out_of_memory) return "out_of_memory";
+    if (v == mkldnn_try_again) return "try_again";
+    if (v == mkldnn_invalid_arguments) return "invalid_arguments";
+    if (v == mkldnn_not_ready) return "not_ready";
+    if (v == mkldnn_unimplemented) return "unimplemented";
+    if (v == mkldnn_iterator_ends) return "iterator_ends";
+    if (v == mkldnn_runtime_error) return "runtime_error";
+    if (v == mkldnn_not_required) return "not_required";
+    assert(!"unknown status");
+    return "unknown status";
+}
+
+/* Maps an mkldnn_data_type_t to its short name ("f32", "s32", ...). */
+const char *mkldnn_dt2str(mkldnn_data_type_t v) {
+    if (v == mkldnn_data_type_undef) return "undef";
+    if (v == mkldnn_f32) return "f32";
+    if (v == mkldnn_s32) return "s32";
+    if (v == mkldnn_s8) return "s8";
+    if (v == mkldnn_u8) return "u8";
+    assert(!"unknown dt");
+    return "unknown dt";
+}
+
+/* Maps an mkldnn_format_kind_t to its name (used by md2fmt_str above). */
+const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) {
+    if (v == mkldnn_format_kind_undef) return "undef";
+    if (v == mkldnn_format_kind_any) return "any";
+    if (v == mkldnn_blocked) return "blocked";
+    if (v == mkldnn_format_kind_wino) return "wino";
+    if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed";
+    assert(!"unknown fmt_kind");
+    return "unknown fmt_kind";
+}
+
+/* Maps an mkldnn_format_tag_t to its canonical name. Auto-generated
+ * one-to-one table: generic abc...-style tags first, then the named
+ * aliases (nchw, oihw, ...) which share enum values with generic tags. */
+const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) {
+    if (v == mkldnn_format_tag_undef) return "undef";
+    if (v == mkldnn_format_tag_any) return "format_tag_any";
+    if (v == mkldnn_a) return "a";
+    if (v == mkldnn_ab) return "ab";
+    if (v == mkldnn_abc) return "abc";
+    if (v == mkldnn_abcd) return "abcd";
+    if (v == mkldnn_abcde) return "abcde";
+    if (v == mkldnn_abcdef) return "abcdef";
+    if (v == mkldnn_abdec) return "abdec";
+    if (v == mkldnn_acb) return "acb";
+    if (v == mkldnn_acbde) return "acbde";
+    if (v == mkldnn_acdb) return "acdb";
+    if (v == mkldnn_acdeb) return "acdeb";
+    if (v == mkldnn_ba) return "ba";
+    if (v == mkldnn_bac) return "bac";
+    if (v == mkldnn_bacd) return "bacd";
+    if (v == mkldnn_bcda) return "bcda";
+    if (v == mkldnn_cba) return "cba";
+    if (v == mkldnn_cdba) return "cdba";
+    if (v == mkldnn_cdeba) return "cdeba";
+    if (v == mkldnn_decab) return "decab";
+    if (v == mkldnn_Abc16a) return "Abc16a";
+    if (v == mkldnn_ABc16a16b) return "ABc16a16b";
+    if (v == mkldnn_aBc16b) return "aBc16b";
+    if (v == mkldnn_ABc16b16a) return "ABc16b16a";
+    if (v == mkldnn_Abc4a) return "Abc4a";
+    if (v == mkldnn_aBc4b) return "aBc4b";
+    if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b";
+    if (v == mkldnn_ABc4b4a) return "ABc4b4a";
+    if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a";
+    if (v == mkldnn_ABc8a8b) return "ABc8a8b";
+    if (v == mkldnn_aBc8b) return "aBc8b";
+    if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b";
+    if (v == mkldnn_ABc8b8a) return "ABc8b8a";
+    if (v == mkldnn_Abcd16a) return "Abcd16a";
+    if (v == mkldnn_ABcd16a16b) return "ABcd16a16b";
+    if (v == mkldnn_aBcd16b) return "aBcd16b";
+    if (v == mkldnn_ABcd16b16a) return "ABcd16b16a";
+    if (v == mkldnn_aBCd16b16c) return "aBCd16b16c";
+    if (v == mkldnn_aBCd16c16b) return "aBCd16c16b";
+    if (v == mkldnn_Abcd4a) return "Abcd4a";
+    if (v == mkldnn_aBcd4b) return "aBcd4b";
+    if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b";
+    if (v == mkldnn_ABcd4b4a) return "ABcd4b4a";
+    if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c";
+    if (v == mkldnn_aBCd4c4b) return "aBCd4c4b";
+    if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a";
+    if (v == mkldnn_ABcd8a8b) return "ABcd8a8b";
+    if (v == mkldnn_aBcd8b) return "aBcd8b";
+    if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b";
+    if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b";
+    if (v == mkldnn_ABcd8b8a) return "ABcd8b8a";
+    if (v == mkldnn_aBCd8b8c) return "aBCd8b8c";
+    if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c";
+    if (v == mkldnn_aBCd8c8b) return "aBCd8c8b";
+    if (v == mkldnn_Abcde16a) return "Abcde16a";
+    if (v == mkldnn_ABcde16a16b) return "ABcde16a16b";
+    if (v == mkldnn_aBcde16b) return "aBcde16b";
+    if (v == mkldnn_ABcde16b16a) return "ABcde16b16a";
+    if (v == mkldnn_aBCde16b16c) return "aBCde16b16c";
+    if (v == mkldnn_aBCde16c16b) return "aBCde16c16b";
+    if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c";
+    if (v == mkldnn_Abcde4a) return "Abcde4a";
+    if (v == mkldnn_aBcde4b) return "aBcde4b";
+    if (v == mkldnn_ABcde4b4a) return "ABcde4b4a";
+    if (v == mkldnn_aBCde4b4c) return "aBCde4b4c";
+    if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c";
+    if (v == mkldnn_aBCde4c4b) return "aBCde4c4b";
+    if (v == mkldnn_Abcde8a) return "Abcde8a";
+    if (v == mkldnn_ABcde8a8b) return "ABcde8a8b";
+    if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b";
+    if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b";
+    if (v == mkldnn_ABcde8b8a) return "ABcde8b8a";
+    if (v == mkldnn_aBCde8b8c) return "aBCde8b8c";
+    if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c";
+    if (v == mkldnn_aBCde8c8b) return "aBCde8c8b";
+    if (v == mkldnn_aBcdef16b) return "aBcdef16b";
+    if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c";
+    if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b";
+    if (v == mkldnn_aBcdef4b) return "aBcdef4b";
+    if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b";
+    if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c";
+    if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c";
+    if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b";
+    if (v == mkldnn_aBdc16b) return "aBdc16b";
+    if (v == mkldnn_aBdc4b) return "aBdc4b";
+    if (v == mkldnn_aBdc8b) return "aBdc8b";
+    if (v == mkldnn_aBdec16b) return "aBdec16b";
+    if (v == mkldnn_aBdec4b) return "aBdec4b";
+    if (v == mkldnn_aBdec8b) return "aBdec8b";
+    if (v == mkldnn_aBdefc16b) return "aBdefc16b";
+    if (v == mkldnn_aBdefc4b) return "aBdefc4b";
+    if (v == mkldnn_aBdefc8b) return "aBdefc8b";
+    if (v == mkldnn_Acb16a) return "Acb16a";
+    if (v == mkldnn_Acb4a) return "Acb4a";
+    if (v == mkldnn_Acb8a) return "Acb8a";
+    if (v == mkldnn_aCBd16b16c) return "aCBd16b16c";
+    if (v == mkldnn_aCBde16b16c) return "aCBde16b16c";
+    if (v == mkldnn_Acdb16a) return "Acdb16a";
+    if (v == mkldnn_Acdb4a) return "Acdb4a";
+    if (v == mkldnn_Acdb8a) return "Acdb8a";
+    if (v == mkldnn_Acdeb16a) return "Acdeb16a";
+    if (v == mkldnn_Acdeb4a) return "Acdeb4a";
+    if (v == mkldnn_Acdeb8a) return "Acdeb8a";
+    if (v == mkldnn_BAc16a16b) return "BAc16a16b";
+    if (v == mkldnn_BAcd16a16b) return "BAcd16a16b";
+    if (v == mkldnn_format_tag_last) return "format_tag_last";
+    /* named aliases below; an alias checked earlier via its generic tag
+     * (same enum value) will already have returned above */
+    if (v == mkldnn_x) return "x";
+    if (v == mkldnn_nc) return "nc";
+    if (v == mkldnn_cn) return "cn";
+    if (v == mkldnn_ncw) return "ncw";
+    if (v == mkldnn_nwc) return "nwc";
+    if (v == mkldnn_nchw) return "nchw";
+    if (v == mkldnn_nhwc) return "nhwc";
+    if (v == mkldnn_chwn) return "chwn";
+    if (v == mkldnn_ncdhw) return "ncdhw";
+    if (v == mkldnn_ndhwc) return "ndhwc";
+    if (v == mkldnn_oi) return "oi";
+    if (v == mkldnn_io) return "io";
+    if (v == mkldnn_oiw) return "oiw";
+    if (v == mkldnn_wio) return "wio";
+    if (v == mkldnn_oihw) return "oihw";
+    if (v == mkldnn_hwio) return "hwio";
+    if (v == mkldnn_ihwo) return "ihwo";
+    if (v == mkldnn_iohw) return "iohw";
+    if (v == mkldnn_oidhw) return "oidhw";
+    if (v == mkldnn_dhwio) return "dhwio";
+    if (v == mkldnn_goiw) return "goiw";
+    if (v == mkldnn_goihw) return "goihw";
+    if (v == mkldnn_hwigo) return "hwigo";
+    if (v == mkldnn_giohw) return "giohw";
+    if (v == mkldnn_goidhw) return "goidhw";
+    if (v == mkldnn_tnc) return "tnc";
+    if (v == mkldnn_ntc) return "ntc";
+    if (v == mkldnn_ldsnc) return "ldsnc";
+    if (v == mkldnn_ldigo) return "ldigo";
+    if (v == mkldnn_ldgoi) return "ldgoi";
+    if (v == mkldnn_ldgo) return "ldgo";
+    if (v == mkldnn_nCdhw16c) return "nCdhw16c";
+    if (v == mkldnn_nCdhw4c) return "nCdhw4c";
+    if (v == mkldnn_nCdhw8c) return "nCdhw8c";
+    if (v == mkldnn_nChw16c) return "nChw16c";
+    if (v == mkldnn_nChw4c) return "nChw4c";
+    if (v == mkldnn_nChw8c) return "nChw8c";
+    if (v == mkldnn_nCw16c) return "nCw16c";
+    if (v == mkldnn_nCw4c) return "nCw4c";
+    if (v == mkldnn_nCw8c) return "nCw8c";
+    if (v == mkldnn_IOw16o16i) return "IOw16o16i";
+    if (v == mkldnn_OIw16i16o) return "OIw16i16o";
+    if (v == mkldnn_OIw16o16i) return "OIw16o16i";
+    if (v == mkldnn_Oiw16o) return "Oiw16o";
+    if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i";
+    if (v == mkldnn_OIw4i4o) return "OIw4i4o";
+    if (v == mkldnn_Oiw4o) return "Oiw4o";
+    if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i";
+    if (v == mkldnn_OIw8i8o) return "OIw8i8o";
+    if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o";
+    if (v == mkldnn_OIw8o8i) return "OIw8o8i";
+    if (v == mkldnn_Owi16o) return "Owi16o";
+    if (v == mkldnn_Owi4o) return "Owi4o";
+    if (v == mkldnn_Owi8o) return "Owi8o";
+    if (v == mkldnn_IOhw16o16i) return "IOhw16o16i";
+    if (v == mkldnn_Ohwi16o) return "Ohwi16o";
+    if (v == mkldnn_Ohwi4o) return "Ohwi4o";
+    if (v == mkldnn_Ohwi8o) return "Ohwi8o";
+    if (v == mkldnn_OIhw16i16o) return "OIhw16i16o";
+    if (v == mkldnn_OIhw16o16i) return "OIhw16o16i";
+    if (v == mkldnn_Oihw16o) return "Oihw16o";
+    if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i";
+    if (v == mkldnn_OIhw4i4o) return "OIhw4i4o";
+    if (v == mkldnn_Oihw4o) return "Oihw4o";
+    if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i";
+    if (v == mkldnn_OIhw8i8o) return "OIhw8i8o";
+    if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o";
+    if (v == mkldnn_OIhw8o8i) return "OIhw8o8i";
+    if (v == mkldnn_Odhwi16o) return "Odhwi16o";
+    if (v == mkldnn_Odhwi4o) return "Odhwi4o";
+    if (v == mkldnn_Odhwi8o) return "Odhwi8o";
+    if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o";
+    if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i";
+    if (v == mkldnn_Oidhw16o) return "Oidhw16o";
+    if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o";
+    if (v == mkldnn_Oidhw4o) return "Oidhw4o";
+    if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i";
+    if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o";
+    if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i";
+    if (v == mkldnn_Goiw16g) return "Goiw16g";
+    if (v == mkldnn_gIOw16o16i) return "gIOw16o16i";
+    if (v == mkldnn_gOIw16i16o) return "gOIw16i16o";
+    if (v == mkldnn_gOIw16o16i) return "gOIw16o16i";
+    if (v == mkldnn_gOiw16o) return "gOiw16o";
+    if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i";
+    if (v == mkldnn_gOIw4i4o) return "gOIw4i4o";
+    if (v == mkldnn_gOiw4o) return "gOiw4o";
+    if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i";
+    if (v == mkldnn_gOIw8i8o) return "gOIw8i8o";
+    if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o";
+    if (v == mkldnn_gOIw8o8i) return "gOIw8o8i";
+    if (v == mkldnn_gOwi16o) return "gOwi16o";
+    if (v == mkldnn_gOwi4o) return "gOwi4o";
+    if (v == mkldnn_gOwi8o) return "gOwi8o";
+    if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i";
+    if (v == mkldnn_gOhwi16o) return "gOhwi16o";
+    if (v == mkldnn_gOhwi4o) return "gOhwi4o";
+    if (v == mkldnn_gOhwi8o) return "gOhwi8o";
+    if (v == mkldnn_Goihw16g) return "Goihw16g";
+    if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o";
+    if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i";
+    if (v == mkldnn_gOihw16o) return "gOihw16o";
+    if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i";
+    if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i";
+    if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o";
+    if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i";
+    if (v == mkldnn_gOihw4o) return "gOihw4o";
+    if (v == mkldnn_Goihw8g) return "Goihw8g";
+    if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i";
+    if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o";
+    if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o";
+    if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i";
+    if (v == mkldnn_gOdhwi16o) return "gOdhwi16o";
+    if (v == mkldnn_gOdhwi4o) return "gOdhwi4o";
+    if (v == mkldnn_gOdhwi8o) return "gOdhwi8o";
+    if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o";
+    if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i";
+    if (v == mkldnn_gOidhw16o) return "gOidhw16o";
+    if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o";
+    if (v == mkldnn_gOidhw4o) return "gOidhw4o";
+    if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i";
+    if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o";
+    if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i";
+    assert(!"unknown fmt_tag");
+    return "unknown fmt_tag";
+}
+
+/* Maps an mkldnn_prop_kind_t (propagation kind) to its name. Note that
+ * aliases sharing an enum value (e.g. forward == forward_training in the
+ * public API) resolve to whichever name is checked first here. */
+const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) {
+    if (v == mkldnn_prop_kind_undef) return "undef";
+    if (v == mkldnn_forward_training) return "forward_training";
+    if (v == mkldnn_forward_inference) return "forward_inference";
+    if (v == mkldnn_forward_scoring) return "forward_scoring";
+    if (v == mkldnn_forward) return "forward";
+    if (v == mkldnn_backward) return "backward";
+    if (v == mkldnn_backward_data) return "backward_data";
+    if (v == mkldnn_backward_weights) return "backward_weights";
+    if (v == mkldnn_backward_bias) return "backward_bias";
+    assert(!"unknown prop_kind");
+    return "unknown prop_kind";
+}
+
+/* Maps an mkldnn_primitive_kind_t to its name for diagnostics. */
+const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) {
+    if (v == mkldnn_undefined_primitive) return "undef";
+    if (v == mkldnn_reorder) return "reorder";
+    if (v == mkldnn_shuffle) return "shuffle";
+    if (v == mkldnn_concat) return "concat";
+    if (v == mkldnn_sum) return "sum";
+    if (v == mkldnn_convolution) return "convolution";
+    if (v == mkldnn_deconvolution) return "deconvolution";
+    if (v == mkldnn_eltwise) return "eltwise";
+    if (v == mkldnn_softmax) return "softmax";
+    if (v == mkldnn_pooling) return "pooling";
+    if (v == mkldnn_lrn) return "lrn";
+    if (v == mkldnn_batch_normalization) return "batch_normalization";
+    if (v == mkldnn_inner_product) return "inner_product";
+    if (v == mkldnn_rnn) return "rnn";
+    assert(!"unknown prim_kind");
+    return "unknown prim_kind";
+}
+
+/* Maps an mkldnn_alg_kind_t (algorithm selector) to its name. */
+const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
+    if (v == mkldnn_alg_kind_undef) return "undef";
+    if (v == mkldnn_convolution_direct) return "convolution_direct";
+    if (v == mkldnn_convolution_winograd) return "convolution_winograd";
+    if (v == mkldnn_convolution_auto) return "convolution_auto";
+    if (v == mkldnn_deconvolution_direct) return "deconvolution_direct";
+    if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd";
+    if (v == mkldnn_eltwise_relu) return "eltwise_relu";
+    if (v == mkldnn_eltwise_tanh) return "eltwise_tanh";
+    if (v == mkldnn_eltwise_elu) return "eltwise_elu";
+    if (v == mkldnn_eltwise_square) return "eltwise_square";
+    if (v == mkldnn_eltwise_abs) return "eltwise_abs";
+    if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt";
+    if (v == mkldnn_eltwise_linear) return "eltwise_linear";
+    if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu";
+    if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu";
+    if (v == mkldnn_eltwise_logistic) return "eltwise_logistic";
+    if (v == mkldnn_pooling_max) return "pooling_max";
+    if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
+    if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
+    if (v == mkldnn_pooling_avg) return "pooling_avg";
+    if (v == mkldnn_lrn_across_channels) return "lrn_across_channels";
+    if (v == mkldnn_lrn_within_channel) return "lrn_within_channel";
+    if (v == mkldnn_vanilla_rnn) return "vanilla_rnn";
+    if (v == mkldnn_vanilla_lstm) return "vanilla_lstm";
+    if (v == mkldnn_vanilla_gru) return "vanilla_gru";
+    if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset";
+    assert(!"unknown alg_kind");
+    return "unknown alg_kind";
+}
+
+/* Maps an mkldnn_rnn_direction_t to its name. */
+const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) {
+    if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right";
+    if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left";
+    if (v == mkldnn_bidirectional_concat) return "bidirectional_concat";
+    if (v == mkldnn_bidirectional_sum) return "bidirectional_sum";
+    if (v == mkldnn_unidirectional) return "unidirectional";
+    assert(!"unknown rnn_direction");
+    return "unknown rnn_direction";
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
new file mode 100644
index 0000000000..7e5789e2c3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
@@ -0,0 +1,115 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_THREAD_HPP
+#define MKLDNN_THREAD_HPP
+
+#include "utils.hpp"
+#include "z_magic.hpp"
+
+/* Threading runtimes the library can be built against. Exactly one is
+ * selected at compile time via MKLDNN_THR. */
+#define MKLDNN_THR_SEQ 0
+#define MKLDNN_THR_OMP 1
+#define MKLDNN_THR_TBB 2
+
+/* Ideally this condition below should never happen (if the library is built
+ * using regular cmake). For the 3rd-party projects that build the library
+ * from the sources on their own try to guess the right threading... */
+#if !defined(MKLDNN_THR)
+# define MKLDNN_THR MKLDNN_THR_TBB
+#endif
+
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+/* Sequential build: a single thread; the barrier is a no-op and all
+ * OpenMP pragmas compile away. MKLDNN_THR_SYNC == 1 means threads of a
+ * team can synchronize via mkldnn_thr_barrier(). */
+#define MKLDNN_THR_SYNC 1
+inline int mkldnn_get_max_threads() { return 1; }
+inline int mkldnn_get_num_threads() { return 1; }
+inline int mkldnn_get_thread_num() { return 0; }
+inline int mkldnn_in_parallel() { return 0; }
+inline void mkldnn_thr_barrier() {}
+
+#define PRAGMA_OMP(...)
+
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+#include <omp.h>
+/* OpenMP build: thin wrappers over the omp_* runtime API. */
+#define MKLDNN_THR_SYNC 1
+
+inline int mkldnn_get_max_threads() { return omp_get_max_threads(); }
+inline int mkldnn_get_num_threads() { return omp_get_num_threads(); }
+inline int mkldnn_get_thread_num() { return omp_get_thread_num(); }
+inline int mkldnn_in_parallel() { return omp_in_parallel(); }
+inline void mkldnn_thr_barrier() {
+#   pragma omp barrier
+}
+
+#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
+
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+#include "tbb/task_arena.h"
+#include "tbb/parallel_for.h"
+/* TBB build: no barrier support (MKLDNN_THR_SYNC == 0), so code paths
+ * that require synchronizable teams must not be taken. */
+#define MKLDNN_THR_SYNC 0
+
+inline int mkldnn_get_max_threads()
+{ return tbb::this_task_arena::max_concurrency(); }
+inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); }
+inline int mkldnn_get_thread_num()
+{ return tbb::this_task_arena::current_thread_index(); }
+inline int mkldnn_in_parallel() { return 0; }
+inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); }
+
+#define PRAGMA_OMP(...)
+
+#endif
+
+/* MSVC still supports omp 2.0 only */
+#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+/* OMP 2.0 has no `collapse` clause and no `simd` pragma: compile both away. */
+# define collapse(x)
+# define PRAGMA_OMP_SIMD(...)
+#else
+# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
+#endif // defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+
+namespace mkldnn {
+namespace impl {
+
+inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; }
+
+/* Splits `n` iterations across `team` threads as evenly as possible and
+ * returns the half-open range [n_start, n_end) owned by thread `tid`:
+ * the first T1 threads receive n1 = ceil(n/team) iterations each, the
+ * rest receive n2 = n1 - 1, so that T1*n1 + (team - T1)*n2 == n. */
+template <typename T, typename U>
+inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
+    T n_min = 1;
+    /* n_end doubles as the per-thread count until the final addition */
+    T &n_my = n_end;
+    if (team <= 1 || n == 0) {
+        n_start = 0;
+        n_my = n;
+    } else if (n_min == 1) {
+        // team = T1 + T2
+        // n = T1*n1 + T2*n2  (n1 - n2 = 1)
+        T n1 = utils::div_up(n, (T)team);
+        T n2 = n1 - 1;
+        T T1 = n - n2 * (T)team;
+        n_my = (T)tid < T1 ? n1 : n2;
+        n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
+    }
+
+    n_end += n_start;
+}
+
+} // namespace impl
+} // namespace mkldnn
+
+#include "mkldnn_thread_parallel_nd.hpp"
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp
new file mode 100644
index 0000000000..50f9b29622
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp
@@ -0,0 +1,277 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
+#define MKLDNN_THREAD_PARALLEL_ND_HPP
+
+/* This header must be included by mkldnn_thread.hpp only */
+
+/* Functions:
+ * - parallel(nthr, f) - executes f in parallel using at most
+ * nthr threads. If nthr equals 0
+ * mkldnn_get_max_threads() threads is
+ * used
+ * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already
+ * created threads
+ * - parallel_nd(dims..., f) - creates a parallel section and then
+ * calls for_nd
+ * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then
+ * calls for_nd (mostly for convenience)
+ */
+
+namespace mkldnn {
+namespace impl {
+
+/* general parallelization */
+/* Runs f(ithr, nthr) on up to `nthr` threads (nthr == 0 means use
+ * mkldnn_get_max_threads()). nthr == 1 short-circuits to a direct call
+ * with no parallel-region overhead. */
+template <typename F>
+void parallel(int nthr, F f) {
+    if (nthr == 0) nthr = mkldnn_get_max_threads();
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+    assert(nthr == 1);
+    f(0, 1);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+    if (nthr == 1) { f(0, 1); return; }
+#   pragma omp parallel num_threads(nthr)
+    f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+    if (nthr == 1) { f(0, 1); return; }
+    /* static_partitioner gives each of the nthr tasks exactly one index */
+    tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner());
+#endif
+}
+
+/* for_nd section */
+
+/* for_nd overloads (1..6 dimensions): each splits the flattened
+ * D0*D1*...*Dk iteration space across `nthr` already-running threads via
+ * balance211 and invokes f(d0, ..., dk) for every multi-index owned by
+ * thread `ithr`, stepping indices with utils::nd_iterator_step. */
+template <typename T0, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
+    T0 start{0}, end{0};
+    balance211(D0, nthr, ithr, start, end);
+    for (T0 d0 = start; d0 < end; ++d0) f(d0);
+}
+
+template <typename T0, typename T1, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
+    const size_t work_amount = (size_t)D0 * D1;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1);
+        utils::nd_iterator_step(d0, D0, d1, D1);
+    }
+}
+
+template <typename T0, typename T1, typename T2, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+        const T2 &D2, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0}; T2 d2{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1, d2);
+        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
+    }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+        const T2 &D2, const T3 &D3, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1, d2, d3);
+        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
+    }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+        typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+        const T2 &D2, const T3 &D3, const T4 &D4, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1, d2, d3, d4);
+        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+    }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+        typename T5, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+        const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
+            d5, D5);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1, d2, d3, d4, d5);
+        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
+    }
+}
+
+// Skip a lambda function in the parameter pack.
+/* Product of all leading integral arguments of a parallel_nd call; the
+ * single-argument base case (the trailing lambda) contributes 1, so the
+ * functor can be passed through without being counted. */
+template <typename T>
+constexpr size_t get_work_amount(const T &v) { return 1; }
+template <typename T, typename ...Args>
+constexpr size_t get_work_amount(const T &v, Args &&...args)
+{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
+
+/* parallel_nd and parallel_nd_in_omp section */
+
+#if MKLDNN_THR != MKLDNN_THR_TBB
+/* SEQ/OMP: creates a parallel region (only when there is more than one
+ * unit of work) and dispatches to the matching for_nd overload; the last
+ * element of `args` is the body functor. */
+template <typename ...Args>
+void parallel_nd(Args &&...args) {
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+    for_nd(0, 1, utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+    const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
+#   pragma omp parallel if (do_parallel)
+    {
+        const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
+        const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
+        for_nd(ithr, nthr, utils::forward<Args>(args)...);
+    }
+#endif
+}
+#else // MKLDNN_THR != MKLDNN_THR_TBB
+
+// gcc 4.8 has a bug with passing parameter pack to lambdas.
+// So have to explicitly instantiate all the cases.
+
+/* TBB parallel_nd overloads (1..6 dims): flatten the index space into
+ * [0, D0*...*Dk), let TBB split it into blocked ranges (static partitioner
+ * for a deterministic assignment), then decode the multi-index once per
+ * range with nd_iterator_init and advance it with nd_iterator_step. */
+template <typename T0, typename F>
+void parallel_nd(const T0 &D0, F f) {
+    const size_t work_amount = (size_t)D0;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(T0(iwork));
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, F f) {
+    const size_t work_amount = (size_t)D0 * D1;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1);
+            utils::nd_iterator_step(d0, D0, d1, D1);
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0}; T2 d2{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1, d2);
+            utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1, d2, d3);
+            utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+        typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
+        const T4 &D4, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1, d2, d3, d4);
+            utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+        typename T5, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
+        const T4 &D4, const T5 &D5, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
+                d5, D5);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1, d2, d3, d4, d5);
+            utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
+        }
+    }, tbb::static_partitioner());
+}
+#endif
+
+/* Convenience wrapper for use inside an existing OpenMP parallel region:
+ * queries the current ithr/nthr and forwards to for_nd. Unsupported
+ * under TBB (asserts), since there is no ambient team to query. */
+template <typename ...Args>
+void parallel_nd_in_omp(Args &&...args) {
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+    for_nd(0, 1, utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+    for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
+            utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+    assert(!"unsupported parallel_nd_in_omp()");
+#endif
+}
+
+} // namespace impl
+} // namespace mkldnn
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp
new file mode 100644
index 0000000000..aa671a0b6e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp
@@ -0,0 +1,77 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_TRAITS_HPP
+#define MKLDNN_TRAITS_HPP
+
+#include <assert.h>
+#include <stdint.h>
+
+#include "mkldnn.h"
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/* Compile-time trait maps used throughout the library. */
+template <data_type_t> struct prec_traits {}; /* ::type -> C++ element type */
+template <typename> struct data_traits {}; /* ::data_type -> data_type_t id */
+template <int> struct typesize_traits {}; /* ::type -> a type of that byte size */
+template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */
+
+template <> struct prec_traits<data_type::f32> { typedef float type; };
+template <> struct prec_traits<data_type::s32> { typedef int32_t type; };
+template <> struct prec_traits<data_type::s8> { typedef int8_t type; };
+template <> struct prec_traits<data_type::u8> { typedef uint8_t type; };
+
+template <> struct data_traits<float>
+{ static constexpr data_type_t data_type = data_type::f32; };
+template <> struct data_traits<int32_t>
+{ static constexpr data_type_t data_type = data_type::s32; };
+template <> struct data_traits<int8_t>
+{ static constexpr data_type_t data_type = data_type::s8; };
+template <> struct data_traits<uint8_t>
+{ static constexpr data_type_t data_type = data_type::u8; };
+
+template <> struct typesize_traits<4> { typedef float type; };
+template <> struct typesize_traits<2> { typedef int16_t type; };
+template <> struct typesize_traits<1> { typedef uint8_t type; };
+
+/* For each primitive kind, expose the matching descriptor struct
+ * (e.g. convolution -> convolution_desc_t) and the query enum value
+ * used to retrieve that descriptor through the query API. */
+#define PKIND_TRAITS_INST(op) \
+template <> struct pkind_traits<primitive_kind::op> { \
+ typedef CONCAT2(op, _desc_t) desc_type; \
+ static constexpr query_t query_d = query::CONCAT2(op, _d); \
+}
+PKIND_TRAITS_INST(convolution);
+PKIND_TRAITS_INST(deconvolution);
+PKIND_TRAITS_INST(shuffle);
+PKIND_TRAITS_INST(eltwise);
+PKIND_TRAITS_INST(softmax);
+PKIND_TRAITS_INST(pooling);
+PKIND_TRAITS_INST(lrn);
+PKIND_TRAITS_INST(batch_normalization);
+PKIND_TRAITS_INST(inner_product);
+PKIND_TRAITS_INST(rnn);
+#undef PKIND_TRAITS_INST
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
new file mode 100644
index 0000000000..f89ea999e2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
@@ -0,0 +1,193 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef NSTL_HPP
+#define NSTL_HPP
+
+#include <stdint.h>
+#include <limits.h>
+#include <float.h>
+
+#include <vector>
+#include <map>
+
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+void *malloc(size_t size, int alignment);
+void free(void *p);
+
+/* Base class that routes all heap allocation of derived objects through
+ * the library's aligned malloc/free, keeping allocation behavior
+ * independent of the C++ runtime. */
+struct c_compatible {
+ enum { default_alignment = 64 }; /* cache-line-friendly alignment */
+ static void *operator new(size_t sz) {
+ return malloc(sz, default_alignment);
+ }
+ /* placement new: no allocation, returns the supplied storage */
+ static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; }
+ static void *operator new[](size_t sz) {
+ return malloc(sz, default_alignment);
+ }
+ static void operator delete(void *p) { free(p); }
+ static void operator delete[](void *p) { free(p); }
+};
+
+namespace nstl {
+
+/* Minimal replacements for <algorithm>/<cmath> helpers, avoiding
+ * dependencies on the C++ runtime library. */
+template<typename T>
+inline const T abs(const T& a) {
+ return a >= 0 ? a : -a;
+}
+
+template<typename T>
+inline const T& max(const T& a, const T& b) {
+ return a > b ? a : b;
+}
+
+template<typename T>
+inline const T& min(const T& a, const T& b) {
+ return a < b ? a : b;
+}
+
+/* copy-based swap; T needs copy construction and assignment */
+template<typename T> void swap(T& t1, T& t2) {
+ T tmp(t1);
+ t1 = t2;
+ t2 = tmp;
+}
+
+// Rationale: MKL-DNN needs numeric limits implementation that does not
+// generate dependencies on C++ run-time libraries.
+
+/* Primary template is only declared; each supported element type gets an
+ * explicit specialization providing lowest()/max(). */
+template<typename T> struct numeric_limits;
+
+template<> struct numeric_limits<float> {
+ static constexpr float lowest() { return -FLT_MAX; }
+ static constexpr float max() { return FLT_MAX; }
+};
+
+template<> struct numeric_limits<int32_t> {
+ static constexpr int lowest() { return INT32_MIN; }
+ static constexpr int max() { return INT32_MAX; }
+};
+
+template<> struct numeric_limits<int16_t> {
+ static constexpr int16_t lowest() { return INT16_MIN; }
+ static constexpr int16_t max() { return INT16_MAX; }
+};
+
+template<> struct numeric_limits<int8_t> {
+ static constexpr int8_t lowest() { return INT8_MIN; }
+ static constexpr int8_t max() { return INT8_MAX; }
+};
+
+template<> struct numeric_limits<uint8_t> {
+ static constexpr uint8_t lowest() { return 0; }
+ static constexpr uint8_t max() { return UINT8_MAX; }
+};
+
+/* Type predicates limited to the element types the library supports
+ * (narrower than std::is_integral on purpose). */
+template<typename T> struct is_integral
+{ static constexpr bool value = false; };
+template<> struct is_integral<int32_t> { static constexpr bool value = true; };
+template<> struct is_integral<int16_t> { static constexpr bool value = true; };
+template<> struct is_integral<int8_t> { static constexpr bool value = true; };
+template<> struct is_integral<uint8_t> { static constexpr bool value = true; };
+
+template <typename T, typename U> struct is_same
+{ static constexpr bool value = false; };
+template <typename T> struct is_same<T, T>
+{ static constexpr bool value = true; };
+
+// Rationale: MKL-DNN needs container implementations that do not generate
+// dependencies on C++ run-time libraries.
+//
+// Implementation philosophy: caller is responsible to check if the operation
+// is valid. The only functions that have to return status are those that
+// depend on memory allocation or similar operations.
+//
+// This means that e.g. an operator [] does not have to check for boundaries.
+// The caller should have checked the boundaries. If it did not we crash and
+// burn: this is a bug in MKL-DNN and throwing an exception would not have been
+// recoverable.
+//
+// On the other hand, insert() or resize() or a similar operation needs to
+// return a status because the outcome depends on factors external to the
+// caller. The situation is probably also not recoverable also, but MKL-DNN
+// needs to be nice and report "out of memory" to the users.
+
+/* Status codes returned by nstl container operations that may allocate. */
+enum nstl_status_t {
+ success = 0,
+ out_of_memory
+};
+
+/* Thin std::vector facade with library-controlled allocation (via
+ * c_compatible). Per the philosophy above, element access performs no
+ * bounds checking; only allocating operations report a status. */
+template <typename T> class vector: public c_compatible {
+private:
+ std::vector<T> _impl;
+public:
+ typedef typename std::vector<T>::iterator iterator;
+ typedef typename std::vector<T>::const_iterator const_iterator;
+ typedef typename std::vector<T>::size_type size_type;
+ vector() {}
+ vector(size_type n): _impl(n) {}
+ vector(size_type n, const T &value): _impl(n, value) {}
+ template <typename input_iterator>
+ vector(input_iterator first, input_iterator last): _impl(first, last) {}
+ ~vector() {}
+ size_type size() const { return _impl.size(); }
+ T& operator[] (size_type i) { return _impl[i]; }
+ const T& operator[] (size_type i) const { return _impl[i]; }
+ iterator begin() { return _impl.begin(); }
+ const_iterator begin() const { return _impl.begin(); }
+ iterator end() { return _impl.end(); }
+ const_iterator end() const { return _impl.end(); }
+ /* inserts [begin, end) before pos; always reports success here —
+ * NOTE(review): allocation failure would escape as an exception from
+ * std::vector rather than map to out_of_memory */
+ template <typename input_iterator>
+ nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end)
+ {
+ _impl.insert(pos, begin, end);
+ return success;
+ }
+ void clear() { _impl.clear(); }
+ void push_back(const T& t) { _impl.push_back(t); }
+ void resize(size_type count) { _impl.resize(count); }
+ void reserve(size_type count) { _impl.reserve(count); }
+};
+
+/* Thin std::map facade with library-controlled allocation (via
+ * c_compatible). As with nstl::vector, validity checks are the
+ * caller's responsibility. */
+template <typename Key, typename T> class map: public c_compatible {
+private:
+ std::map<Key, T> _impl;
+public:
+ typedef typename std::map<Key, T>::iterator iterator;
+ typedef typename std::map<Key, T>::const_iterator const_iterator;
+ typedef typename std::map<Key, T>::size_type size_type;
+ map() {}
+ ~map() {}
+ size_type size() const { return _impl.size(); }
+ T& operator[](const Key &k) { return _impl[k]; }
+ /* Fix: std::map::operator[] is non-const, so the const overload must
+ * use find(); the caller must guarantee the key exists. */
+ const T& operator[](const Key &k) const { return _impl.find(k)->second; }
+ iterator begin() { return _impl.begin(); }
+ const_iterator begin() const { return _impl.begin(); }
+ iterator end() { return _impl.end(); }
+ const_iterator end() const { return _impl.end(); }
+ /* Fix: clear() was declared as a template on a stray, non-deducible
+ * input_iterator parameter, which made any plain m.clear() call
+ * ill-formed. It is a regular member function now. */
+ void clear() { _impl.clear(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
new file mode 100644
index 0000000000..be96e654ff
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
@@ -0,0 +1,114 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+/* Initializes *pool_desc for a forward or backward pooling primitive.
+ *
+ * For forward prop kinds the memory descriptors are stored as src/dst;
+ * for backward they are stored as diff_src/diff_dst. A null padding_r
+ * defaults the right/bottom padding to padding_l. Returns
+ * invalid_arguments on malformed or inconsistent inputs. */
+status_t pooling_desc_init(pooling_desc_t *pool_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t kernel, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ bool args_ok = true
+ && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l)
+ && one_of(alg_kind, pooling_max,
+ pooling_avg_include_padding,
+ pooling_avg_exclude_padding)
+ && one_of(padding_kind, padding_kind::padding_zero);
+ if (!args_ok) return invalid_arguments;
+
+ if (padding_r == nullptr) padding_r = padding_l;
+
+ auto pd = pooling_desc_t();
+ pd.primitive_kind = primitive_kind::pooling;
+ pd.prop_kind = prop_kind;
+ pd.alg_kind = alg_kind;
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+
+ /* Zero all four descriptors, then fill only the pair matching the
+ * propagation direction. (Fix: removed a dead store to
+ * pd.src_desc.ndims that was unconditionally overwritten by the
+ * zero_md() assignment right below it.) */
+ pd.diff_src_desc = pd.src_desc = zero_md();
+ pd.diff_dst_desc = pd.dst_desc = zero_md();
+
+ (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc;
+ (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc;
+
+ /* spatial dims only: minibatch and channels are excluded */
+ int sp_dims = src_desc->ndims - 2;
+ utils::array_copy(pd.strides, strides, sp_dims);
+ utils::array_copy(pd.kernel, kernel, sp_dims);
+ utils::array_copy(pd.padding[0], padding_l, sp_dims);
+ utils::array_copy(pd.padding[1], padding_r, sp_dims);
+
+ pd.padding_kind = padding_kind;
+ if (one_of(alg_kind, pooling_max, pooling_avg_include_padding,
+ pooling_avg_exclude_padding)) {
+ pd.accum_data_type = types::default_accum_data_type(
+ src_desc->data_type, dst_desc->data_type);
+ } else {
+ pd.accum_data_type = dst_desc->data_type;
+ }
+
+ /* each output spatial dim must satisfy the pooling output formula:
+ * OD = (ID - K + PL + PR) / S + 1 */
+ bool consistency = true
+ && utils::one_of(src_desc->ndims, 4, 5)
+ && utils::one_of(dst_desc->ndims, 4, 5)
+ && src_desc->dims[0] == dst_desc->dims[0]
+ && src_desc->dims[1] == dst_desc->dims[1];
+ for (int i = 2; i < src_desc->ndims; ++i)
+ consistency = consistency && (
+ (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2]
+ + padding_r[i - 2]) / strides[i - 2] + 1
+ == dst_desc->dims[i]);
+ if (!consistency) return invalid_arguments;
+
+ *pool_desc = pd;
+ return success;
+}
+}
+
+/* Public C API: forward pooling descriptor. Validates that prop_kind is
+ * a forward kind, then delegates to the common initializer above. */
+status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t kernel, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc,
+ dst_desc, strides, kernel, padding_l, padding_r, padding_kind);
+}
+
+/* Public C API: backward pooling descriptor; prop_kind is implicitly
+ * backward_data, and the descriptors are treated as diff_src/diff_dst. */
+status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc,
+ alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t kernel, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind,
+ diff_src_desc, diff_dst_desc, strides, kernel, padding_l,
+ padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
new file mode 100644
index 0000000000..4c9c009412
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
@@ -0,0 +1,238 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef POOLING_PD_HPP
+#define POOLING_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct pooling_fwd_pd_t;
+
+/* Base primitive descriptor for pooling. All accessors below read the
+ * src/dst descriptors for forward and diff_src/diff_dst for backward
+ * (see the private src_desc()/dst_desc() switches at the bottom). */
+struct pooling_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::pooling;
+
+ pooling_pd_t(engine_t *engine,
+ const pooling_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const pooling_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , ws_md_()
+ {}
+
+ const pooling_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ /* answers query::pooling_d with the op descriptor; everything else is
+ * forwarded to the base class */
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::pooling_d:
+ *(const pooling_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common pooling aux functions */
+
+ /* minibatch and channel counts */
+ dim_t MB() const { return src_desc().dims[0]; }
+ dim_t C() const { return src_desc().dims[1]; }
+
+ /* input/output spatial sizes; depth (and height) default to 1 when
+ * the tensor has fewer dims (4D tensors = 2D pooling, 5D = 3D) */
+ dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; }
+ dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; }
+ dim_t IW() const { return src_desc().dims[ndims() - 1]; }
+
+ dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; }
+ dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; }
+ dim_t OW() const { return dst_desc().dims[ndims() - 1]; }
+
+ /* kernel sizes; kernel[]/strides[]/padding[] hold spatial dims only,
+ * hence the ndims()-5/-4/-3 offsets */
+ dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; }
+ dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; }
+ dim_t KW() const { return desc_.kernel[ndims() - 3]; }
+
+ /* kernel strides */
+ dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+ dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+ dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+ /* padding: [0] = left/top/front side, [1] = right/bottom/back side */
+ dim_t padFront() const
+ { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+ dim_t padBack() const
+ { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+ dim_t padT() const
+ { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+ dim_t padB() const
+ { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+ dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+ dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+ int ndims() const { return src_desc().ndims; }
+ bool is_3d() const { return ndims() == 5; }
+
+ bool has_zero_dim_memory() const
+ { return memory_desc_wrapper(src_desc()).has_zero_dim(); }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+protected:
+ pooling_desc_t desc_;
+ const pooling_fwd_pd_t *hint_fwd_pd_;
+
+ /* workspace (max-pooling indices); zero md when not used */
+ memory_desc_t ws_md_;
+
+ void init_default_ws() {
+ ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
+ ws_md_.data_type = indices_data_type();
+ }
+
+ /* picks the narrowest type that can hold an index into the kernel
+ * window: u8 when the window volume fits in 255 values, else s32 */
+ data_type_t indices_data_type() const {
+ /* the simplest way to express 256... */
+ const int u8_max = nstl::numeric_limits<
+ typename prec_traits<data_type::u8>::type>::max();
+ return utils::array_product(desc()->kernel, ndims()) <= u8_max
+ ? data_type::u8 : data_type::s32;
+ }
+
+private:
+ /* direction-aware descriptor selection used by all accessors above */
+ const memory_desc_t &src_desc() const
+ { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; }
+ const memory_desc_t &dst_desc() const
+ { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; }
+};
+
+/* Forward pooling primitive descriptor: src in, dst out, plus an
+ * optional workspace output (max-pooling indices for backward). */
+struct pooling_fwd_pd_t: public pooling_pd_t {
+ typedef pooling_fwd_pd_t base_class;
+ typedef pooling_fwd_pd_t hint_class;
+
+ pooling_fwd_pd_t(engine_t *engine,
+ const pooling_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const pooling_fwd_pd_t *hint_fwd_pd)
+ : pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , dst_md_(desc_.dst_desc)
+ {}
+
+ /* declares how each execution argument is used by this primitive */
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+ /* non-null only when init_default_ws() populated ws_md_ */
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override
+ { return 1 + (workspace_md() != nullptr); }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t dst_md_;
+
+ /* if the user left dst format as `any`, derive it from src's
+ * blocked layout; src itself must already be blocked */
+ virtual status_t set_default_params() {
+ if (dst_md()->format_kind != format_kind::any)
+ return status::success;
+
+ if (src_md()->format_kind != format_kind::blocked)
+ return status::unimplemented;
+
+ return memory_desc_init_by_blocking_desc(dst_md_,
+ src_md_.format_desc.blocking);
+ }
+};
+
+/* Backward pooling primitive descriptor: diff_dst in, diff_src out,
+ * with the forward workspace (if any) consumed as an input. */
+struct pooling_bwd_pd_t: public pooling_pd_t {
+ typedef pooling_bwd_pd_t base_class;
+ typedef pooling_fwd_pd_t hint_class;
+
+ pooling_bwd_pd_t(engine_t *engine,
+ const pooling_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const pooling_fwd_pd_t *hint_fwd_pd)
+ : pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_md_(desc_.diff_src_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ /* declares how each execution argument is used by this primitive */
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_DIFF_DST)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ /* workspace was produced by forward; here it is read, not written */
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::input;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ virtual int n_inputs() const override
+ { return 1 + (workspace_md() != nullptr); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_src_md_;
+ memory_desc_t diff_dst_md_;
+
+ /* if diff_src format is `any`, derive it from diff_dst's blocked
+ * layout; diff_dst itself must already be blocked */
+ virtual status_t set_default_params() {
+ if (diff_src_md()->format_kind != format_kind::any)
+ return status::success;
+
+ if (diff_dst_md()->format_kind != format_kind::blocked)
+ return status::unimplemented;
+
+ return memory_desc_init_by_blocking_desc(diff_src_md_,
+ diff_dst_md_.format_desc.blocking);
+ }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
new file mode 100644
index 0000000000..fdf6522f62
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
@@ -0,0 +1,103 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "primitive.hpp"
+#include "type_helpers.hpp"
+#include "stream.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::primitive_kind;
+
+namespace {
+// XXX: this is a huge hammer. This disables all and any msan checks on
+// primitives outputs.
+//
+// A proper approach would be an implementation-specific unpoisoning.
+/* Marks every writable (non-const) memory argument as initialized for
+ * MemorySanitizer, since msan cannot see writes done by JITed kernels. */
+void unpoison_outputs(const exec_args_t &args) {
+ for(const auto &arg: args) {
+ if (arg.second.is_const) continue; /* inputs stay as-is */
+ auto *mem = arg.second.mem;
+ void *p;
+ mem->get_data_handle(&p);
+ size_t s = memory_desc_wrapper(*mem->md()).size();
+ msan_unpoison(p, s);
+ }
+}
+}
+
+/* Releases a primitive descriptor; a null pointer is accepted. */
+status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) {
+ delete primitive_desc; /* delete on nullptr is a no-op */
+ return success;
+}
+
+/* Instantiates a primitive from its (already created) descriptor.
+ * Returns invalid_arguments if either pointer is null. */
+status_t mkldnn_primitive_create(primitive_t **primitive,
+ const primitive_desc_t *primitive_desc) {
+ if (utils::any_null(primitive, primitive_desc))
+ return invalid_arguments;
+ return primitive_desc->create_primitive(primitive);
+}
+
+/* Executes @p primitive on @p stream with @p nargs C-style arguments.
+ * The primitive and stream must share an engine. When verbose mode is
+ * on, wall-clock execution time is printed to stdout. */
+status_t mkldnn_primitive_execute(const primitive_t *primitive,
+ stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) {
+ bool ok = true
+ && !utils::any_null(primitive, stream)
+ && primitive->engine() == stream->engine()
+ && IMPLICATION(nargs > 0, c_args != nullptr);
+ if (!ok) return invalid_arguments;
+
+ /* convert the C argument array into the internal exec_args_t map;
+ * NOTE(review): "cvt_primtive_args" [sic] matches the helper's
+ * declared name elsewhere in the library */
+ exec_args_t args;
+ status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args);
+ if (status != status::success) return status;
+
+ exec_ctx_t ctx(stream, std::move(args));
+
+ if (mkldnn_verbose()->level) {
+ double ms = get_msec();
+ status = primitive->execute(ctx);
+ ms = get_msec() - ms;
+ printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms);
+ fflush(0); /* flush all open output streams */
+ } else {
+ status = primitive->execute(ctx);
+ }
+
+ /* see unpoison_outputs() above: keep msan quiet about JITed writes */
+ if (msan_enabled) unpoison_outputs(ctx.args());
+
+ return status;
+}
+
+/* Returns (borrows) the descriptor the primitive was created from;
+ * the caller must not destroy it. */
+status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive,
+ const primitive_desc_t **primitive_desc) {
+ if (utils::any_null(primitive, primitive_desc))
+ return invalid_arguments;
+ return safe_ptr_assign<const primitive_desc_t>(*primitive_desc,
+ primitive->pd());
+}
+
+/* Releases a primitive; a null pointer is accepted and ignored. */
+status_t mkldnn_primitive_destroy(primitive_t *primitive) {
+ delete primitive; /* delete on nullptr is a no-op */
+ return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
new file mode 100644
index 0000000000..3b506d6d1f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
@@ -0,0 +1,76 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_HPP
+#define PRIMITIVE_HPP
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+#include "primitive_exec_types.hpp"
+
+/** \brief A pure virtual primitive class
+ *
+ * Primitive contains links to its inputs & outputs, though it does not track
+ * their readiness on execution step.
+ *
+ * @remark @b Rational.
+ * Dependencies are essential through-out the whole MKL-DNN library, so it
+ * makes sense to include them on the very low level. On the other hand,
+ * tracking them should be a task for corresponding essence, like scheduler,
+ * stream or whatever. Primitive itself should know nothing about the
+ * environment it is running in.
+ *
+ * @note
+ * To make user experience better we should provide API which allows
+ * achieving the best (or good enough) performance when creating primitives
+ * in natural order: i.e. from bottom to top for forward pass and from top to
+ * bottom for backward pass. Please consider restriction [1] in Level 0.
+ */
+struct mkldnn_primitive: public mkldnn::impl::c_compatible {
+ /* owns a private clone of the descriptor, so the user may destroy
+ * theirs right after creation */
+ mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd)
+ : pd_(pd->clone()) {}
+ virtual ~mkldnn_primitive() { delete pd_; }
+
+ /** returns primitive's engine */
+ mkldnn::impl::engine_t *engine() const { return pd_->engine(); }
+ /** returns primitive's primitive descriptor */
+ const mkldnn::impl::primitive_desc_t *pd() const { return pd_; }
+ /** returns primitive's kind */
+ mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); }
+
+ /** executes primitive with execution context @p ctx */
+ virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx)
+ const = 0;
+
+protected:
+ const mkldnn::impl::primitive_desc_t *pd_;
+
+private:
+ /* non-copyable, non-movable: pd_ ownership must stay unique */
+ mkldnn_primitive() = delete;
+ mkldnn_primitive(const mkldnn_primitive &) = delete;
+ mkldnn_primitive(mkldnn_primitive &&) = delete;
+ mkldnn_primitive &operator=(const mkldnn_primitive &) = delete;
+ mkldnn_primitive &operator=(mkldnn_primitive &&) = delete;
+};
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp
new file mode 100644
index 0000000000..9fd638842c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp
@@ -0,0 +1,290 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_attr.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
+namespace mkldnn {
+namespace impl {
+
+status_t scales_t::set(dim_t count, int mask, const float *scales) {
+ cleanup();
+
+ count_ = count;
+ mask_ = mask;
+
+ if (count_ == 1) {
+ scales_ = scales_buf_;
+ utils::array_set(scales_, scales[0], scales_buf_size);
+ } else {
+ scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
+ if (scales_ == nullptr)
+ return status::out_of_memory;
+
+ for (dim_t c = 0; c < count_; ++c)
+ scales_[c] = scales[c];
+ }
+
+ return status::success;
+}
+
+}
+}
+
+status_t post_ops_t::append_sum(float scale) {
+ if (len_ == capacity)
+ return out_of_memory;
+
+ entry_[len_].kind = primitive_kind::sum;
+ entry_[len_].sum.scale = scale;
+
+ len_++;
+
+ return success;
+}
+
+status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
+ float beta) {
+ using namespace mkldnn::impl::alg_kind;
+ bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic);
+ if (!known_alg)
+ return invalid_arguments;
+
+ if (len_ == capacity)
+ return out_of_memory;
+
+ entry_[len_].kind = primitive_kind::eltwise;
+ entry_[len_].eltwise.scale = scale;
+ entry_[len_].eltwise.alg = alg;
+ entry_[len_].eltwise.alpha = alpha;
+ entry_[len_].eltwise.beta = beta;
+
+ len_++;
+
+ return success;
+}
+
+status_t primitive_attr_t::set_scratchpad_mode(
+ scratchpad_mode_t scratchpad_mode) {
+ using namespace mkldnn::impl::scratchpad_mode;
+
+ const bool ok = one_of(scratchpad_mode, library, user);
+ if (!ok)
+ return invalid_arguments;
+
+ scratchpad_mode_ = scratchpad_mode;
+ return success;
+}
+
+status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
+ this->post_ops_ = post_ops;
+ return success;
+}
+
+/* Public C API */
+
+status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
+ if (attr == nullptr)
+ return invalid_arguments;
+
+ return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
+ new mkldnn_primitive_attr);
+}
+
+status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
+ const primitive_attr_t *existing_attr) {
+ if (any_null(attr, existing_attr))
+ return invalid_arguments;
+
+ return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
+ existing_attr->clone());
+}
+
+status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
+ if (attr)
+ delete attr;
+
+ return success;
+}
+
+status_t mkldnn_primitive_attr_get_scratchpad_mode(
+ const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
+ if (any_null(attr, scratchpad_mode))
+ return invalid_arguments;
+
+ *scratchpad_mode = attr->scratchpad_mode_;
+
+ return success;
+}
+
+status_t mkldnn_primitive_attr_set_scratchpad_mode(
+ primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
+ if (any_null(attr))
+ return invalid_arguments;
+
+ return attr->set_scratchpad_mode(scratchpad_mode);
+}
+
+status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
+ dim_t *count, int *mask, const float **scales) {
+ if (any_null(attr, count, mask, scales))
+ return invalid_arguments;
+
+ *count = attr->output_scales_.count_;
+ *mask = attr->output_scales_.mask_;
+ *scales = attr->output_scales_.scales_;
+
+ return success;
+}
+
+status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
+ dim_t count, int mask, const float *scales) {
+ bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
+ if (!ok)
+ return invalid_arguments;
+
+ return attr->output_scales_.set(count, mask, scales);
+}
+
+status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
+ const post_ops_t **post_ops) {
+ if (any_null(attr, post_ops))
+ return invalid_arguments;
+
+ *post_ops = &attr->post_ops_;
+ return success;
+}
+
+status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
+ const post_ops_t *post_ops) {
+ if (any_null(attr, post_ops))
+ return invalid_arguments;
+
+ return attr->set_post_ops(*post_ops);
+}
+
+status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
+ if (post_ops == nullptr)
+ return invalid_arguments;
+
+ return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
+}
+
+status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
+ if (post_ops)
+ delete post_ops;
+
+ return success;
+}
+
+int mkldnn_post_ops_len(const post_ops_t *post_ops) {
+ if (post_ops)
+ return post_ops->len_;
+
+ return 0;
+}
+
+primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
+ int index) {
+ bool ok = post_ops && 0 <= index && index < post_ops->len_;
+ if (!ok)
+ return primitive_kind::undefined;
+
+ return post_ops->entry_[index].kind;
+}
+
+status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
+ if (post_ops == nullptr)
+ return invalid_arguments;
+
+ return post_ops->append_sum(scale);
+}
+
+namespace {
+bool simple_get_params_check(const post_ops_t *post_ops, int index,
+ primitive_kind_t kind) {
+ bool ok = true
+ && post_ops != nullptr
+ && 0 <= index
+ && index < post_ops->len_
+ && post_ops->entry_[index].kind == kind;
+ return ok;
+}
+}
+
+status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
+ float *scale) {
+ bool ok = true
+ && simple_get_params_check(post_ops, index, primitive_kind::sum)
+ && !any_null(scale);
+ if (!ok)
+ return invalid_arguments;
+
+ *scale = post_ops->entry_[index].sum.scale;
+ return success;
+}
+
+status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
+ alg_kind_t kind, float alpha, float beta) {
+ if (post_ops == nullptr)
+ return invalid_arguments;
+
+ return post_ops->append_eltwise(scale, kind, alpha, beta);
+}
+
+status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
+ int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
+ bool ok = true
+ && simple_get_params_check(post_ops, index, primitive_kind::eltwise)
+ && !any_null(scale, alpha, beta);
+ if (!ok)
+ return invalid_arguments;
+
+ const auto &e = post_ops->entry_[index].eltwise;
+ *scale = e.scale;
+ *alg = e.alg;
+ *alpha = e.alpha;
+ *beta = e.beta;
+
+ return success;
+}
+
+status_t mkldnn_primitive_attr_set_rnn_data_qparams(
+ primitive_attr_t *attr, const float scale, const float shift) {
+ if (attr == nullptr)
+ return invalid_arguments;
+
+ return attr->rnn_data_qparams_.set(scale, shift);
+}
+
+status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
+ primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
+ bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
+ if (!ok)
+ return invalid_arguments;
+
+ return attr->rnn_weights_qparams_.set(count, mask, scales);
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
new file mode 100644
index 0000000000..e2130c7ab1
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
@@ -0,0 +1,183 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_ATTR_HPP
+#define PRIMITIVE_ATTR_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct rnn_data_qparams_t : public c_compatible {
+ rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
+ bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
+
+ status_t set(float scale, float shift) {
+ scale_ = scale;
+ shift_ = shift;
+ return status::success;
+ }
+
+ float scale_;
+ float shift_;
+};
+
+struct scales_t: public c_compatible {
+ scales_t(): count_(1), mask_(0), scales_(scales_buf_)
+ { set(1.); }
+
+ scales_t(const scales_t &rhs): scales_t()
+ { set(rhs.count_, rhs.mask_, rhs.scales_); }
+
+ ~scales_t() { cleanup(); }
+
+ scales_t &operator=(const scales_t &rhs) {
+ if (&rhs == this)
+ return *this;
+ status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
+ assert(status == status::success);
+ (void)status;
+ return *this;
+ }
+
+ bool has_default_values() const {
+ for (dim_t c = 0; c < count_; ++c) {
+ if(scales_[c] != 1.) return false;
+ }
+ return true;
+ }
+
+ status_t set(dim_t count, int mask, const float *scales);
+ status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
+
+ dim_t count_;
+ int mask_;
+ float *scales_;
+
+private:
+ enum { scales_buf_size = 16 };
+ float scales_buf_[scales_buf_size];
+
+ void cleanup() {
+ if (scales_ != scales_buf_ && scales_ != nullptr)
+ impl::free(scales_);
+
+ count_ = 1;
+ mask_ = 0;
+ scales_ = scales_buf_;
+ }
+};
+
+}
+}
+
+struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
+ struct entry_t {
+ struct eltwise_t {
+ mkldnn::impl::alg_kind_t alg;
+ float scale, alpha, beta;
+ };
+
+ mkldnn::impl::primitive_kind_t kind;
+ union {
+ struct { float scale; } sum;
+ eltwise_t eltwise;
+ };
+
+ bool is_eltwise(bool require_scale_one = true) const {
+ using namespace mkldnn::impl;
+ return kind == primitive_kind::eltwise
+ && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
+ }
+
+ bool is_relu(bool require_scale_one = true,
+ bool require_nslope_zero = true) const {
+ using namespace mkldnn::impl;
+ return is_eltwise(require_scale_one)
+ && eltwise.alg == alg_kind::eltwise_relu
+ && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
+ }
+
+ bool is_sum(bool require_scale_one = true) const {
+ using namespace mkldnn::impl;
+ return kind == primitive_kind::sum
+ && IMPLICATION(require_scale_one, sum.scale == 1.f);
+ }
+ };
+
+ mkldnn_post_ops(): len_(0) {}
+
+ mkldnn::impl::status_t append_sum(float scale);
+ mkldnn::impl::status_t append_eltwise(float scale,
+ mkldnn::impl::alg_kind_t alg, float alpha, float beta);
+
+ int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
+ int stop = -1) const {
+ if (stop == -1) stop = len_;
+ stop = mkldnn::impl::nstl::min(stop, len_);
+ for (int idx = start; idx < stop; ++idx)
+ if (entry_[idx].kind == kind) return idx;
+ return -1;
+ }
+
+ bool has_default_values() const { return len_ == 0; }
+
+ bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
+ { return find(kind, index, index + 1) == index; }
+
+ enum { capacity = 4 };
+
+ int len_;
+ entry_t entry_[capacity];
+};
+
+struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
+ mkldnn_primitive_attr()
+ : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
+ {}
+
+ mkldnn_primitive_attr *clone() const
+ { return new mkldnn_primitive_attr(*this); }
+
+ /** Returns true if the attributes have default values.
+ *
+ * @note The scratchpad_mode_ is not taken into account */
+ bool has_default_values() const {
+ return true
+ && output_scales_.has_default_values()
+ && post_ops_.has_default_values()
+ && rnn_data_qparams_.has_default_values()
+ && rnn_weights_qparams_.has_default_values();
+ }
+
+ mkldnn::impl::status_t set_scratchpad_mode(
+ mkldnn::impl::scratchpad_mode_t scratchpad_mode);
+ mkldnn::impl::status_t set_post_ops(
+ const mkldnn::impl::post_ops_t &post_ops);
+
+ mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
+ mkldnn::impl::scales_t output_scales_;
+ mkldnn::impl::post_ops_t post_ops_;
+ mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
+ mkldnn::impl::scales_t rnn_weights_qparams_;
+};
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
new file mode 100644
index 0000000000..723c41e05a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
@@ -0,0 +1,78 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
+ auto safe_ret_md = [&](const memory_desc_t *_) {
+ if (_ == nullptr) return not_required;
+ *(const memory_desc_t **)result = _;
+ return success;
+ };
+
+ switch (what) {
+ case query::engine: *(engine_t**)result = engine(); break;
+ case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
+
+ case query::scratchpad_engine:
+ *(engine_t**)result = scratchpad_engine(); break;
+
+ case query::memory_consumption_s64:
+ *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break;
+
+ case query::op_d:
+ if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
+ *(const_c_op_desc_t *)result
+ = static_cast<const_c_op_desc_t>(op_desc()); break;
+
+ case query::src_md: return safe_ret_md(src_md(idx));
+ case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
+ case query::dst_md: return safe_ret_md(dst_md(idx));
+ case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
+ case query::weights_md: return safe_ret_md(weights_md(idx));
+ case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx));
+ case query::workspace_md:
+ if (idx != 0) return status::invalid_arguments;
+ return safe_ret_md(workspace_md(idx));
+ case query::scratchpad_md:
+ if (idx != 0) return status::invalid_arguments;
+ return safe_ret_md(scratchpad_md(idx));
+
+ case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
+ case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
+
+ case query::impl_info_str: *(const char **)result = name(); break;
+
+ default: return unimplemented;
+ }
+ return success;
+}
+
+status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
+ const primitive_attr_t **attr) {
+ if (utils::any_null(primitive_desc, attr))
+ return invalid_arguments;
+
+ *attr = primitive_desc->attr();
+ return success;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp
new file mode 100644
index 0000000000..536dcfa1d0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp
@@ -0,0 +1,174 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_DESC_HPP
+#define PRIMITIVE_DESC_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "primitive_attr.hpp"
+#include "verbose.hpp"
+
+struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
+ using md_t = mkldnn::impl::memory_desc_t;
+
+ mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_attr_t *attr,
+ mkldnn::impl::primitive_kind_t kind)
+ : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
+
+ mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
+ mkldnn::impl::primitive_kind_t kind)
+ : engine_(engine), kind_(kind) { info_[0] = '\0'; }
+
+ virtual mkldnn_primitive_desc *clone() const = 0;
+ virtual ~mkldnn_primitive_desc() {}
+
+ const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
+ mkldnn::impl::engine_t *engine() const { return engine_; }
+ mkldnn::impl::primitive_kind_t kind() const { return kind_; }
+
+ virtual void init_info() {}
+ const char *info() const { return info_; }
+
+ mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
+ { return scratchpad_registry_; }
+ const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
+ { return scratchpad_registry_; }
+ virtual mkldnn::impl::engine_t *scratchpad_engine() const
+ { return engine_; }
+
+ virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }
+
+ enum class arg_usage_t { unused, input, output };
+ virtual arg_usage_t arg_usage(
+ mkldnn::impl::primitive_arg_index_t arg) const {
+ using mkldnn::impl::types::is_zero_md;
+ if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
+ return arg_usage_t::output;
+ return arg_usage_t::unused;
+ }
+
+# define DECLARE_MD_STUB(stub) \
+ virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
+ { return nullptr; }
+
+ DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
+ DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
+ DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
+ DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
+ DECLARE_MD_STUB(workspace_md);
+# undef DECLARE_MD_STUB
+
+ const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
+ return idx == 0 ? &scratchpad_md_ : nullptr;
+ }
+
+ virtual void init_scratchpad_md() {
+ auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
+ mkldnn::impl::dims_t dims = { size };
+ mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
+ mkldnn::impl::data_type::u8, mkldnn_x);
+ }
+
+ /** returns the scratchpad size for the given scratchpad mode. */
+ mkldnn::impl::dim_t scratchpad_size(
+ mkldnn::impl::scratchpad_mode_t mode) const {
+ if (mode != attr_.scratchpad_mode_) return 0;
+ return scratchpad_registry().size();
+ }
+
+ virtual int n_inputs() const { return 0; }
+ virtual int n_outputs() const { return 0; }
+
+ virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
+ void *result) const;
+
+ virtual mkldnn::impl::status_t create_primitive(
+ mkldnn::impl::primitive_t **primitive) const = 0;
+
+ virtual const char *name() const { return "mkldnn_primitive_desc"; }
+
+ /* static magic */
+
+ template<typename pd_t>
+ static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
+ const mkldnn::impl::op_desc_t *adesc,
+ const mkldnn::impl::primitive_attr_t *attr,
+ mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_desc_t *hint_fwd) {
+ using namespace mkldnn::impl;
+ using namespace mkldnn::impl::status;
+ using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
+ if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
+ assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
+ auto hint =
+ reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
+ auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
+ if (_pd == nullptr) return out_of_memory;
+ if (_pd->init() != success) { delete _pd; return unimplemented; }
+ _pd->init_info();
+ _pd->init_scratchpad_md();
+ *pd = _pd;
+ return success;
+ }
+
+protected:
+ mkldnn::impl::engine_t *engine_;
+ mkldnn::impl::primitive_attr_t attr_;
+ mkldnn::impl::primitive_kind_t kind_;
+
+ mkldnn::impl::memory_desc_t scratchpad_md_;
+
+ char info_[MKLDNN_VERBOSE_BUF_LEN];
+
+ mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
+
+protected:
+ /** compares ws between fwd_pd and this (makes sense to use for bwd_pd)
+ * Expectation: this has already set its workspace, and that workspace
+ * should exactly match the one from fwd_pd */
+ bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
+ using namespace mkldnn::impl;
+ if (!workspace_md()) return true; // the impl lives fine w/o workspace
+ return fwd_pd && fwd_pd->workspace_md()
+ && *fwd_pd->workspace_md() == *workspace_md();
+ }
+};
+
+#define DECLARE_COMMON_PD_t(impl_name, ...) \
+ virtual pd_t *clone() const override { return new pd_t(*this); } \
+ virtual status_t create_primitive(primitive_t **p) const override { \
+ double ms = get_msec(); \
+ auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
+ ms = get_msec() - ms; \
+ if (mkldnn_verbose()->level >= 2) { \
+ printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
+ fflush(0); \
+ } \
+ return ret; \
+ } \
+ virtual const char *name() const override { return impl_name; }
+#define DECLARE_COMMON_PD_T(impl_name, ...) \
+ DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp
new file mode 100644
index 0000000000..43e5a31ef3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp
@@ -0,0 +1,90 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "memory.hpp"
+#include "primitive.hpp"
+#include "primitive_exec_types.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
+ const mkldnn_exec_arg_t *c_args, exec_args_t &args) {
+ using namespace status;
+
+ if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments;
+
+ int n_inputs = 0;
+ int n_outputs = 0;
+
+ for (int i = 0; i < nargs; ++i) {
+ primitive_arg_index_t arg = c_args[i].arg;
+ auto *mem = c_args[i].memory;
+
+ switch (pd->arg_usage(arg)) {
+ case primitive_desc_t::arg_usage_t::input:
+ if (args.count(arg) != 0) return invalid_arguments;
+ args[arg] = {mem, true};
+ n_inputs++;
+ break;
+ case primitive_desc_t::arg_usage_t::output:
+ if (args.count(arg) != 0) return invalid_arguments;
+ args[arg] = {mem, false};
+ n_outputs++;
+ break;
+ case primitive_desc_t::arg_usage_t::unused:
+ break;
+ }
+ }
+
+ bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md());
+
+ if (n_inputs != pd->n_inputs()) return invalid_arguments;
+ if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0))
+ return invalid_arguments;
+
+ return success;
+}
+
+const void *exec_ctx_t::input(primitive_arg_index_t arg) const {
+ if (args_.count(arg) != 1) return nullptr;
+ const auto ma = args_.at(arg);
+ assert(ma.is_const);
+ void *ptr;
+ status_t status = ma.mem->get_data_handle(&ptr);
+ assert(status == status::success); MAYBE_UNUSED(status);
+ return ptr;
+}
+
+void *exec_ctx_t::output(primitive_arg_index_t arg) const {
+ if (args_.count(arg) != 1) return nullptr;
+ const auto ma = args_.at(arg);
+ assert(!ma.is_const);
+ void *ptr;
+ status_t status = ma.mem->get_data_handle(&ptr);
+ assert(status == status::success); MAYBE_UNUSED(status);
+ return ptr;
+}
+
+const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const {
+ assert(args_.count(arg) == 1);
+ const auto ma = args_.at(arg);
+ assert(!ma.is_const);
+ return ma.mem;
+}
+
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
new file mode 100644
index 0000000000..0645891da7
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_EXEC_TYPES_HPP
+#define PRIMITIVE_EXEC_TYPES_HPP
+
+#include <unordered_map>
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "memory.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct memory_arg_t {
+ memory_t *mem;
+ bool is_const;
+};
+
+using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>;
+
+status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
+ const mkldnn_exec_arg_t *c_args, exec_args_t &args);
+
+/** Primitive execution context (helps pass the stream, memories, and events). */
+struct exec_ctx_t {
+ exec_ctx_t(const exec_ctx_t &) = default;
+ exec_ctx_t(exec_ctx_t &&) = default;
+
+ exec_ctx_t(stream_t *stream): stream_(stream) {}
+ exec_ctx_t(stream_t *stream, exec_args_t &&args)
+ : stream_(stream)
+ , args_(std::move(args)) {}
+
+ stream_t *stream() const { return stream_; }
+ const exec_args_t &args() const { return args_; }
+
+ /* tentative solution... TODO: replace with functions returning memory_t */
+ const void *input(primitive_arg_index_t arg) const;
+ void *output(primitive_arg_index_t arg) const;
+
+ const memory_t *memory(primitive_arg_index_t arg) const;
+
+private:
+ stream_t *stream_;
+ exec_args_t args_;
+};
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp
new file mode 100644
index 0000000000..5a1cd7d379
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp
@@ -0,0 +1,89 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+#include "primitive_iterator.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_primitive_desc_iterator_create(
+ primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc,
+ const primitive_attr_t *attr, engine_t *engine,
+ const primitive_desc_t *hint_fwd_pd) {
+ const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
+
+ auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd);
+ if (it == nullptr) return out_of_memory;
+
+ ++(*it);
+ if (*it == it->end()) {
+ delete it;
+ return unimplemented;
+ }
+
+ *iterator = it;
+ return success;
+}
+
+status_t mkldnn_primitive_desc_iterator_next(
+ primitive_desc_iterator_t *iterator) {
+ if (iterator == nullptr) return invalid_arguments;
+ ++(*iterator);
+ return *iterator == iterator->end() ? iterator_ends : success;
+}
+
+primitive_desc_t *mkldnn_primitive_desc_iterator_fetch(
+ const primitive_desc_iterator_t *iterator) {
+ if (iterator == nullptr) return nullptr;
+ return *(*iterator);
+}
+
+status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc,
+ const primitive_desc_t *existing_primitive_desc) {
+ if (utils::any_null(primitive_desc, existing_primitive_desc))
+ return invalid_arguments;
+ return safe_ptr_assign<primitive_desc_t>(*primitive_desc,
+ existing_primitive_desc->clone());
+}
+
+status_t mkldnn_primitive_desc_iterator_destroy(
+ primitive_desc_iterator_t *iterator) {
+ if (iterator != nullptr)
+ delete iterator;
+ return success;
+}
+
+status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc,
+ const_c_op_desc_t c_op_desc, const primitive_attr_t *attr,
+ engine_t *engine, const primitive_desc_t *hint_fwd_pd) {
+ const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
+
+ mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd);
+ ++it;
+ if (it == it.end()) return unimplemented;
+
+ return safe_ptr_assign<primitive_desc_t>(*primitive_desc, *it);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp
new file mode 100644
index 0000000000..4e88ab3aa5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp
@@ -0,0 +1,79 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+#ifndef PRIMITIVE_ITERATOR_HPP
+#define PRIMITIVE_ITERATOR_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible {
+ using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
+
+ mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc,
+ const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd)
+ : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc)
+ , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd)
+ , impl_list_(engine_->get_implementation_list()), last_idx_(0)
+ {
+ while (impl_list_[last_idx_] != nullptr) ++last_idx_;
+ }
+ ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; }
+
+ bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
+ { return idx_ == rhs.idx_ && engine_ == rhs.engine_; }
+ bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
+ { return !operator==(rhs); }
+
+ mkldnn::impl::primitive_desc_iterator_t end() const
+ { return mkldnn_primitive_desc_iterator(engine_, last_idx_); }
+
+ mkldnn::impl::primitive_desc_iterator_t &operator++() {
+ if (pd_) { delete pd_; pd_ = nullptr; }
+ while (++idx_ != last_idx_) {
+ auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_,
+ hint_fwd_pd_);
+ if (s == mkldnn::impl::status::success) break;
+ }
+ return *this;
+ }
+
+ mkldnn::impl::primitive_desc_t *operator*() const {
+ if (*this == end() || pd_ == nullptr) return nullptr;
+ return pd_->clone();
+ }
+
+protected:
+ int idx_;
+ mkldnn::impl::engine_t *engine_;
+ mkldnn::impl::primitive_desc_t *pd_;
+ const mkldnn::impl::op_desc_t *op_desc_;
+ const mkldnn::impl::primitive_attr_t attr_;
+ const mkldnn::impl::primitive_desc_t *hint_fwd_pd_;
+ const pd_create_f *impl_list_;
+ int last_idx_;
+
+private:
+ mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx)
+ : idx_(last_idx), engine_(engine), pd_(nullptr)
+ , op_desc_(nullptr), hint_fwd_pd_(nullptr)
+ , impl_list_(nullptr), last_idx_(last_idx) {}
+};
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp
new file mode 100644
index 0000000000..835cd73581
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/query.cpp
@@ -0,0 +1,59 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc,
+ query_t what, int index, void *result) {
+ if (any_null(primitive_desc, result))
+ return invalid_arguments;
+
+ return primitive_desc->query(what, index, result);
+}
+
+const memory_desc_t *mkldnn_primitive_desc_query_md(
+ const primitive_desc_t *primitive_desc, query_t what, int index) {
+ const memory_desc_t *res_md = nullptr;
+ bool args_ok = true
+ && primitive_desc != nullptr
+ && (what & query::some_md) == query::some_md
+ && what != query::some_md
+ && mkldnn_primitive_desc_query(primitive_desc,
+ what, index, &res_md) == success;
+ return args_ok ? res_md : nullptr;
+}
+
+int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc,
+ query_t what, int index) {
+ int res_s32;
+ bool args_ok = primitive_desc != nullptr
+ && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32)
+ && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32)
+ == success;
+ return args_ok ? res_s32 : 0;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
new file mode 100644
index 0000000000..d11f1a0361
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "reorder_pd.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_reorder_primitive_desc_create(
+ primitive_desc_t **reorder_pd,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md,
+ const primitive_attr_t *attr) {
+ if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md))
+ return invalid_arguments;
+
+ auto s_ek = src_engine->kind();
+ auto d_ek = dst_engine->kind();
+ if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek)))
+ return invalid_arguments;
+
+ auto r_pd = reinterpret_cast<reorder_pd_t **>(reorder_pd);
+ auto s_mdw = memory_desc_wrapper(*src_md);
+ auto d_mdw = memory_desc_wrapper(*dst_md);
+
+ if (!s_mdw.consistent_with(d_mdw))
+ return invalid_arguments;
+
+ auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine;
+
+ const primitive_attr_t dummy_attr;
+ if (attr == NULL)
+ attr = &dummy_attr;
+
+ for (auto r = e->get_reorder_implementation_list(); *r; ++r) {
+ if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md)
+ == success) {
+ (*r_pd)->init_info();
+ (*r_pd)->init_scratchpad_md();
+ return success;
+ }
+ }
+ return unimplemented;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp
new file mode 100644
index 0000000000..963cb0f58a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp
@@ -0,0 +1,85 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REORDER_PD_HPP
+#define REORDER_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "primitive_attr.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct reorder_pd_t: public primitive_desc_t {
+ reorder_pd_t(engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md)
+ : primitive_desc_t(engine, attr, primitive_kind::reorder)
+ , src_engine_(src_engine)
+ , dst_engine_(dst_engine)
+ , scratchpad_engine_(nullptr)
+ , src_md_(*src_md)
+ , dst_md_(*dst_md)
+ {}
+
+ virtual const op_desc_t *op_desc() const override { return nullptr; }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_FROM)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_TO)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override { return 1; }
+
+ float alpha() const { return attr()->output_scales_.scales_[0]; }
+ float beta() const {
+ const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
+ return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale;
+ }
+ virtual mkldnn::impl::engine_t *scratchpad_engine() const override
+ { return scratchpad_engine_; }
+
+protected:
+ engine_t *src_engine_;
+ engine_t *dst_engine_;
+ engine_t *scratchpad_engine_;
+
+ memory_desc_t src_md_;
+ memory_desc_t dst_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
new file mode 100644
index 0000000000..36967431a6
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
@@ -0,0 +1,400 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "cpu/gemm/os_blas.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::types;
+using namespace mkldnn::impl::utils;
+
+namespace {
+memory_desc_t copy_maybe_null(const memory_desc_t *md) {
+ return md ? *md : zero_md();
+}
+
+rnn_desc_t zero_rnn_desc() {
+ auto rd = rnn_desc_t();
+ rd.src_layer_desc = zero_md();
+ rd.src_iter_desc = zero_md();
+ rd.weights_layer_desc = zero_md();
+ rd.weights_iter_desc = zero_md();
+ rd.bias_desc = zero_md();
+ rd.dst_layer_desc = zero_md();
+ rd.dst_iter_desc = zero_md();
+ rd.diff_src_layer_desc = zero_md();
+ rd.diff_src_iter_desc = zero_md();
+ rd.diff_weights_layer_desc = zero_md();
+ rd.diff_weights_iter_desc = zero_md();
+ rd.diff_bias_desc = zero_md();
+ rd.diff_dst_layer_desc = zero_md();
+ rd.diff_dst_iter_desc = zero_md();
+ return rd;
+}
+}
+
+/* Public C Api */
+
+status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc,
+ mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f,
+ unsigned int flags, float alpha, float clipping) {
+ using namespace mkldnn::impl::alg_kind;
+
+ bool args_ok = true
+ && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru,
+ gru_linear_before_reset)
+ && IMPLICATION(cell_kind == vanilla_rnn,
+ one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic));
+ if (!args_ok)
+ return invalid_arguments;
+
+ auto rcd = mkldnn_rnn_cell_desc_t();
+
+ rcd.cell_kind = cell_kind;
+ rcd.activation_kind = act_f;
+ rcd.flags = flags;
+ rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0;
+ rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0;
+
+ *rnn_cell_desc = rcd;
+
+ return success;
+}
+
+int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) {
+ switch (rnn_cell_desc->cell_kind) {
+ case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
+ case mkldnn::impl::alg_kind::vanilla_gru: return 3;
+ case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3;
+ case mkldnn::impl::alg_kind::vanilla_lstm: return 4;
+ default: assert(!"unknown cell kind"); return 0;
+ }
+ return 0;
+}
+
+int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) {
+ switch (rnn_cell_desc->cell_kind) {
+ case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
+ case mkldnn::impl::alg_kind::vanilla_gru: return 1;
+ case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1;
+ case mkldnn::impl::alg_kind::vanilla_lstm: return 2;
+ default: assert(!"unknown cell kind"); return 0;
+ }
+ return 0;
+}
+
+status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc,
+ prop_kind_t prop_kind, const memory_desc_t *src_layer_desc,
+ const memory_desc_t *src_iter_desc,
+ const memory_desc_t *weights_layer_desc,
+ const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_layer_desc,
+ const memory_desc_t *dst_iter_desc) {
+ using namespace data_type;
+ data_type_t src_layer_dt = src_layer_desc->data_type;
+ data_type_t dst_layer_dt = dst_layer_desc->data_type;
+ data_type_t weights_iter_dt = weights_iter_desc->data_type;
+ data_type_t weights_layer_dt = weights_layer_desc->data_type;
+
+ bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt,
+ weights_layer_dt)
+ && IMPLICATION(!is_zero_md(src_iter_desc),
+ src_iter_desc->data_type == f32)
+ && IMPLICATION(!is_zero_md(dst_iter_desc),
+ dst_iter_desc->data_type == f32)
+ && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+#if USE_MKL_PACKED_GEMM
+ bool is_u8u8u8 = src_layer_dt == u8
+ && IMPLICATION(!is_zero_md(src_iter_desc),
+ src_iter_desc->data_type == u8)
+ && IMPLICATION(!is_zero_md(dst_iter_desc),
+ dst_iter_desc->data_type == u8)
+ && one_of(dst_layer_dt, u8, f32)
+ && everyone_is(s8, weights_iter_dt, weights_layer_dt)
+ && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+ bool is_f32u8f32 = src_layer_dt == u8
+ && IMPLICATION(!is_zero_md(src_iter_desc),
+ src_iter_desc->data_type == f32)
+ && IMPLICATION(!is_zero_md(dst_iter_desc),
+ dst_iter_desc->data_type == f32)
+ && one_of(dst_layer_dt, u8, f32)
+ && everyone_is(s8, weights_iter_dt, weights_layer_dt)
+ && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+ bool is_inference = prop_kind == prop_kind::forward_inference;
+ bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm;
+
+ return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference))
+ ? success
+ : unimplemented;
+#else
+ return is_f32 ? success : unimplemented;
+#endif
+}
+
+status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc,
+ rnn_direction_t direction, int L, int D, int T, int N, int S, int G,
+ int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc,
+ const memory_desc_t *src_iter_desc,
+ const memory_desc_t *weights_layer_desc,
+ const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_layer_desc,
+ const memory_desc_t *dst_iter_desc) {
+ bool args_ok;
+
+ // * algorithm specific
+ args_ok = true
+ && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru,
+ DIC == SIC);
+ if (!args_ok) return invalid_arguments;
+ int extra_bias =
+ rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset;
+
+ // * on num layers
+ args_ok = true
+ && L == weights_layer_desc->dims[0]
+ && L == weights_iter_desc->dims[0]
+ && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0])
+ && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0])
+ && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on num directions
+ args_ok = true
+ && D == weights_layer_desc->dims[1]
+ && D == weights_iter_desc->dims[1]
+ && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1])
+ && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1])
+ && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on num iterations
+ args_ok = true
+ && T == src_layer_desc->dims[0]
+ && T == dst_layer_desc->dims[0];
+ if (!args_ok) return invalid_arguments;
+
+ // * on mb
+ args_ok = true
+ && N == src_layer_desc->dims[1]
+ && N == dst_layer_desc->dims[1]
+ && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3])
+ && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on num gates
+ args_ok = true
+ && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc)
+ && G == weights_layer_desc->dims[3]
+ && G == weights_iter_desc->dims[3]
+ && IMPLICATION(!is_zero_md(bias_desc),
+ G + extra_bias == bias_desc->dims[2]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on num states
+ args_ok = true
+ && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc)
+ && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2])
+ && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on slc
+ args_ok = true
+ && SLC == weights_layer_desc->dims[2]
+ && SLC == src_layer_desc->dims[2];
+ if (!args_ok) return invalid_arguments;
+
+ // * on sic
+ args_ok = true
+ && SIC == weights_iter_desc->dims[2]
+ && IMPLICATION(!is_zero_md(src_iter_desc),
+ SIC == src_iter_desc->dims[4]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on dlc
+ int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1;
+ args_ok = true
+ && DLC == dlc_multiplier * DIC
+ && DLC == dst_layer_desc->dims[2];
+ if (!args_ok) return invalid_arguments;
+
+ // * on dic
+ args_ok = true
+ && DIC == weights_layer_desc->dims[4]
+ && DIC == weights_iter_desc->dims[4]
+ && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3])
+ && IMPLICATION(!is_zero_md(dst_iter_desc),
+ DIC == dst_iter_desc->dims[4]);
+ if (!args_ok) return invalid_arguments;
+
+ // * unrolling/fusion conditions
+ args_ok = true
+ && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC)
+ && IMPLICATION(T > 1, SIC == DIC);
+ if (!args_ok) return invalid_arguments;
+
+ return success;
+}
+
+status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
+ prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
+ const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
+ const memory_desc_t *src_iter_desc,
+ const memory_desc_t *weights_layer_desc,
+ const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_layer_desc,
+ const memory_desc_t *dst_iter_desc) {
+ bool args_ok = true && rnn_cell_desc != nullptr
+ && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
+ dst_layer_desc);
+ if (!args_ok) return invalid_arguments;
+
+ //check dimensions consistency
+ int L = weights_layer_desc->dims[0];
+ int T = src_layer_desc->dims[0];
+ int N = src_layer_desc->dims[1];
+ const int D = one_of(direction, mkldnn_unidirectional_left2right,
+ mkldnn_unidirectional_right2left) ?
+ 1 :
+ 2;
+ int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
+ int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
+ int SLC = src_layer_desc->dims[2];
+ int SIC = weights_iter_desc->dims[2];
+ int DLC = dst_layer_desc->dims[2];
+ int DIC = weights_layer_desc->dims[4];
+
+ CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+ G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
+ weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
+ dst_iter_desc));
+
+ CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind,
+ src_layer_desc, src_iter_desc, weights_layer_desc,
+ weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc));
+
+ // Create the descriptor
+ mkldnn_rnn_desc_t rd = zero_rnn_desc();
+
+ rd.primitive_kind = primitive_kind::rnn;
+ rd.prop_kind = prop_kind;
+ rd.cell_desc = *rnn_cell_desc;
+ rd.direction = direction;
+ rd.src_layer_desc = copy_maybe_null(src_layer_desc);
+ rd.src_iter_desc = copy_maybe_null(src_iter_desc);
+ rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
+ rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
+ rd.bias_desc = copy_maybe_null(bias_desc);
+ rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
+ rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
+
+ *rnn_desc = rd;
+
+ return success;
+}
+
+status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
+ prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
+ const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
+ const memory_desc_t *src_iter_desc,
+ const memory_desc_t *weights_layer_desc,
+ const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
+ const memory_desc_t *diff_src_layer_desc,
+ const memory_desc_t *diff_src_iter_desc,
+ const memory_desc_t *diff_weights_layer_desc,
+ const memory_desc_t *diff_weights_iter_desc,
+ const memory_desc_t *diff_bias_desc,
+ const memory_desc_t *diff_dst_layer_desc,
+ const memory_desc_t *diff_dst_iter_desc) {
+ bool args_ok = true
+ && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
+ dst_layer_desc, diff_src_layer_desc,
+ diff_weights_layer_desc, diff_weights_iter_desc,
+ diff_dst_layer_desc);
+ if (!args_ok)
+ return invalid_arguments;
+
+ auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) {
+ return is_zero_md(a_md) == is_zero_md(b_md);
+ };
+
+ args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
+ && xnor_md(dst_iter_desc, diff_dst_iter_desc)
+ && xnor_md(src_iter_desc, diff_src_iter_desc);
+ if (!args_ok)
+ return invalid_arguments;
+
+ //check dimensions consistency
+ int L = weights_layer_desc->dims[0];
+ int T = src_layer_desc->dims[0];
+ int N = src_layer_desc->dims[1];
+ const int D = one_of(direction, mkldnn_unidirectional_left2right,
+ mkldnn_unidirectional_right2left) ?
+ 1 :
+ 2;
+ int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
+ int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
+ int SLC = src_layer_desc->dims[2];
+ int SIC = weights_iter_desc->dims[2];
+ int DLC = dst_layer_desc->dims[2];
+ int DIC = weights_layer_desc->dims[4];
+
+ status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+ G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
+ weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
+ dst_iter_desc);
+ if (st != success) return st;
+
+ st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+ G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc,
+ diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc,
+ diff_dst_layer_desc, diff_dst_iter_desc);
+ if (st != success) return st;
+
+ mkldnn_rnn_desc_t rd = zero_rnn_desc();
+
+ rd.primitive_kind = primitive_kind::rnn;
+ rd.prop_kind = prop_kind;
+ rd.cell_desc = *rnn_cell_desc;
+ rd.direction = direction;
+
+ rd.src_layer_desc = copy_maybe_null(src_layer_desc);
+ rd.src_iter_desc = copy_maybe_null(src_iter_desc);
+ rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
+ rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
+ rd.bias_desc = copy_maybe_null(bias_desc);
+ rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
+ rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
+ rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc);
+ rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc);
+ rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc);
+ rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc);
+ rd.diff_bias_desc = copy_maybe_null(diff_bias_desc);
+ rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc);
+ rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc);
+
+ *rnn_desc = rd;
+
+ return success;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
new file mode 100644
index 0000000000..1ee2ba1114
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
@@ -0,0 +1,280 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef RNN_PD_HPP
+#define RNN_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct rnn_fwd_pd_t;
+
+struct rnn_pd_t : public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::rnn;
+
+ rnn_pd_t(engine_t *engine,
+ const rnn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const rnn_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , src_layer_md_(desc_.src_layer_desc)
+ , src_iter_md_(desc_.src_iter_desc)
+ , weights_layer_md_(desc_.weights_layer_desc)
+ , weights_iter_md_(desc_.weights_iter_desc)
+ , bias_md_(desc_.bias_desc)
+ , dst_layer_md_(desc_.dst_layer_desc)
+ , dst_iter_md_(desc_.dst_iter_desc)
+ , ws_md_()
+ {}
+
+ const rnn_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::rnn_d: *(const rnn_desc_t **)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override {
+ if (index == 0) return &src_layer_md_;
+ if (index == 1 && with_src_iter()) return &src_iter_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_layer_md_;
+ if (index == 1) return &weights_iter_md_;
+ if (index == 2 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *dst_md(int index = 0) const override {
+ if (index == 0) return &dst_layer_md_;
+ if (index == 1 && with_dst_iter()) return &dst_iter_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ /* common pooling aux functions */
+
+ bool is_training() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::backward);
+ }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+ dim_t T() const { return desc_.src_layer_desc.dims[0]; }
+ dim_t MB() const { return desc_.src_layer_desc.dims[1]; }
+
+ dim_t L() const { return desc_.weights_layer_desc.dims[0]; }
+ dim_t D() const { return desc_.weights_layer_desc.dims[1]; }
+
+ dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; }
+
+ dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; }
+ dim_t G() const { return desc_.weights_layer_desc.dims[3]; }
+ dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; }
+
+ dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; }
+
+ bool with_bias() const
+ { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
+
+ bool with_src_iter() const
+ { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); }
+
+ bool with_dst_iter() const
+ { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); }
+
+ mkldnn::impl::alg_kind_t cell_kind() const
+ { return desc_.cell_desc.cell_kind; }
+ mkldnn::impl::alg_kind_t activation_kind() const
+ { return desc_.cell_desc.activation_kind; }
+
+ bool is_lbr() const
+ { return cell_kind() == mkldnn_gru_linear_before_reset; }
+
+ mkldnn_rnn_direction_t direction() const { return desc_.direction; }
+
+protected:
+ rnn_desc_t desc_;
+ const rnn_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t src_layer_md_;
+ memory_desc_t src_iter_md_;
+ memory_desc_t weights_layer_md_;
+ memory_desc_t weights_iter_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t dst_layer_md_;
+ memory_desc_t dst_iter_md_;
+
+ memory_desc_t ws_md_;
+};
+
+struct rnn_fwd_pd_t: public rnn_pd_t {
+ typedef rnn_fwd_pd_t base_class;
+ typedef rnn_fwd_pd_t hint_class;
+
+ rnn_fwd_pd_t(engine_t *engine,
+ const rnn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const rnn_fwd_pd_t *hint_fwd_pd)
+ : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC_LAYER)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter())
+ return arg_usage_t::input;
+
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
+ MKLDNN_ARG_WEIGHTS_ITER))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_BIAS && with_bias())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST_LAYER)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter())
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && is_training())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual int n_inputs() const override
+ { return 3 + with_bias() + with_src_iter(); }
+ virtual int n_outputs() const override
+ { return 1 + with_dst_iter() + is_training(); }
+};
+
+struct rnn_bwd_pd_t : public rnn_pd_t {
+ typedef rnn_bwd_pd_t base_class;
+ typedef rnn_fwd_pd_t hint_class;
+
+ rnn_bwd_pd_t(engine_t *engine,
+ const rnn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const rnn_fwd_pd_t *hint_fwd_pd)
+ : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_layer_md_(desc_.diff_src_layer_desc)
+ , diff_src_iter_md_(desc_.diff_src_iter_desc)
+ , diff_weights_layer_md_(desc_.diff_weights_layer_desc)
+ , diff_weights_iter_md_(desc_.diff_weights_iter_desc)
+ , diff_bias_md_(desc_.diff_bias_desc)
+ , diff_dst_layer_md_(desc_.diff_dst_layer_desc)
+ , diff_dst_iter_md_(desc_.diff_dst_iter_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER,
+ MKLDNN_ARG_DIFF_DST_LAYER))
+ return arg_usage_t::input;
+
+ if (with_src_iter()) {
+ if (arg == MKLDNN_ARG_SRC_ITER)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC_ITER)
+ return arg_usage_t::output;
+ }
+
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
+ MKLDNN_ARG_WEIGHTS_ITER))
+ return arg_usage_t::input;
+
+ if (with_bias()) {
+ if (arg == MKLDNN_ARG_BIAS)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_BIAS)
+ return arg_usage_t::output;
+ }
+
+ if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER)
+ && with_dst_iter())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_WORKSPACE)
+ return arg_usage_t::input;
+
+ if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER,
+ MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
+ MKLDNN_ARG_DIFF_WEIGHTS_ITER))
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override {
+ if (index == 0) return &diff_src_layer_md_;
+ if (index == 1 && with_src_iter()) return &diff_src_iter_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *diff_weights_md(
+ int index = 0) const override {
+ if (index == 0) return &diff_weights_layer_md_;
+ if (index == 1) return &diff_weights_iter_md_;
+ if (index == 2 && with_bias()) return &diff_bias_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override {
+ if (index == 0) return &diff_dst_layer_md_;
+ if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override
+ { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); }
+ virtual int n_outputs() const override
+ { return 3 + with_src_iter() + with_bias(); }
+
+protected:
+ memory_desc_t diff_src_layer_md_;
+ memory_desc_t diff_src_iter_md_;
+ memory_desc_t diff_weights_layer_md_;
+ memory_desc_t diff_weights_iter_md_;
+ memory_desc_t diff_bias_md_;
+ memory_desc_t diff_dst_layer_md_;
+ memory_desc_t diff_dst_iter_md_;
+};
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp
new file mode 100644
index 0000000000..6bc14fc72a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp
@@ -0,0 +1,112 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "scratchpad.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/* Allocating memory buffers on a page boundary to reduce TLB/page misses */
+const size_t page_size = 2097152; /* 2 MiB */
+
+/*
+  Implementation of the scratchpad_t interface that is compatible with
+  a concurrent execution: every instance owns its own private buffer,
+  at the cost of one allocation per instance.
+*/
+struct concurent_scratchpad_t : public scratchpad_t {
+    concurent_scratchpad_t(size_t size) {
+        size_ = size;
+        /* two-argument malloc(size, page_size) — presumably the project's
+           aligned allocator from utils.hpp (TODO confirm). Allocation
+           failure is only caught by assert, so in release builds a failed
+           allocation leaves scratchpad_ == nullptr and get() returns it
+           silently. */
+        scratchpad_ = (char *) malloc(size, page_size);
+        assert(scratchpad_ != nullptr);
+    }
+
+    ~concurent_scratchpad_t() {
+        free(scratchpad_);
+    }
+
+    /* returns the private buffer; stable for this instance's lifetime */
+    virtual char *get() const {
+        return scratchpad_;
+    }
+
+private:
+    char *scratchpad_; /* owned buffer, released in the destructor */
+    size_t size_;      /* requested size in bytes */
+};
+
+/*
+  Implementation of the scratchpad_t interface that uses a single global
+  scratchpad shared by all instances. The buffer only grows: it is
+  reallocated when a request exceeds the current capacity, and freed when
+  the last instance is destroyed.
+
+  NOTE(review): scratchpad_, size_ and reference_count_ are plain statics
+  mutated without synchronization (reference_count_++ is not atomic), so
+  constructing/destroying these objects is not thread-safe — confirm all
+  creation happens on one thread.
+*/
+struct global_scratchpad_t : public scratchpad_t {
+    global_scratchpad_t(size_t size) {
+        if (size > size_) {
+            /* grow-only policy: drop the old buffer, allocate a larger one */
+            if (scratchpad_ != nullptr) free(scratchpad_);
+            size_ = size;
+            /* allocation failure is only caught by assert (debug builds) */
+            scratchpad_ = (char *) malloc(size, page_size);
+            assert(scratchpad_ != nullptr);
+        }
+        reference_count_++;
+    }
+
+    ~global_scratchpad_t() {
+        reference_count_--;
+        if (reference_count_ == 0) {
+            /* last user gone: release the shared buffer */
+            free(scratchpad_);
+            scratchpad_ = nullptr;
+            size_ = 0;
+        }
+    }
+
+    virtual char *get() const {
+        return scratchpad_;
+    }
+
+private:
+    /*
+      Using thread-local here is unnecessary and even buggy! All threads
+      actually share the same scratchpad, which is created and queried only
+      on the main thread. If the scratchpad is queried on some thread other
+      than the one it was created on (e.g. the application calls the API from
+      multiple threads), thread-local causes a segfault because the scratchpad
+      is uninitialized on the current thread.
+    */
+    /*thread_local*/ static char *scratchpad_;
+    /*thread_local*/ static size_t size_;
+    /*thread_local*/ static unsigned int reference_count_;
+};
+
+/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr;
+/*thread_local*/ size_t global_scratchpad_t::size_ = 0;
+/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0;
+
+
+/*
+  Factory: selects the scratchpad flavor at compile time. With concurrent
+  execution enabled each caller gets a private buffer; otherwise all
+  callers share the single global one.
+*/
+scratchpad_t *create_scratchpad(size_t size) {
+#ifdef MKLDNN_ENABLE_CONCURRENT_EXEC
+    return new concurent_scratchpad_t(size);
+#else
+    return new global_scratchpad_t(size);
+#endif
+}
+
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp
new file mode 100644
index 0000000000..f7a246bc99
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp
@@ -0,0 +1,36 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef COMMON_SCRATCHPAD_HPP
+#define COMMON_SCRATCHPAD_HPP
+
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/* Abstract interface for a temporary ("scratchpad") memory buffer. */
+struct scratchpad_t {
+    virtual ~scratchpad_t() {}
+    /* returns a pointer to the underlying buffer */
+    virtual char *get() const = 0;
+};
+
+/* factory — returns a heap-allocated scratchpad of `size` bytes;
+   the caller owns (and deletes) the result */
+scratchpad_t *create_scratchpad(size_t size);
+
+}
+}
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp
new file mode 100644
index 0000000000..e32e735224
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp
@@ -0,0 +1,72 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+/* Common initializer for shuffle descriptors (forward and backward).
+ *
+ * Fills *shuffle_desc after validating the arguments. Returns:
+ *   - success            for a well-formed descriptor
+ *   - invalid_arguments  for NULL pointers, an unsupported prop_kind, an
+ *                        out-of-range axis, or a group_size that is not a
+ *                        divisor of the dimension at `axis`
+ */
+status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind,
+        const memory_desc_t *data_desc, int axis, dim_t group_size) {
+    bool args_ok = true
+        && !any_null(shuffle_desc, data_desc)
+        && one_of(prop_kind, forward_training, forward_inference,
+                backward, backward_data)
+        && axis >= 0 && axis < data_desc->ndims
+        && group_size > 0 && group_size <= data_desc->dims[axis];
+    if (!args_ok) return invalid_arguments;
+
+    /* build into a local so the output is only written once fully valid */
+    auto sd = shuffle_desc_t();
+    sd.primitive_kind = primitive_kind::shuffle;
+    sd.prop_kind = prop_kind;
+    sd.data_desc = *data_desc;
+    sd.axis = axis;
+    sd.group_size = group_size;
+
+    /* the shuffled dimension must split evenly into groups */
+    bool consistency = true
+        && sd.data_desc.dims[axis] % sd.group_size == 0;
+    if (!consistency) return invalid_arguments;
+
+    *shuffle_desc = sd;
+    return success;
+}
+}
+
+/* Public C API: initializes a forward shuffle descriptor.
+ * Rejects anything but the two forward propagation kinds, then delegates
+ * to the shared initializer. */
+status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc,
+        prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis,
+        dim_t group_size) {
+    const bool is_fwd
+            = one_of(prop_kind, forward_training, forward_inference);
+    if (!is_fwd) return invalid_arguments;
+    return shuffle_desc_init(
+            shuffle_desc, prop_kind, data_desc, axis, group_size);
+}
+
+/* Public C API: initializes a backward shuffle descriptor.
+ * The prop_kind is fixed to backward_data; `diff_data_desc` describes both
+ * the diff source and diff destination (shuffle uses one descriptor for
+ * input and output, see shuffle_pd.hpp). */
+status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc,
+        const memory_desc_t *diff_data_desc, int axis, dim_t group_size) {
+    return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis,
+            group_size);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp
new file mode 100644
index 0000000000..cc5553fe7f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp
@@ -0,0 +1,121 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SHUFFLE_PD_HPP
+#define SHUFFLE_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/* Primitive descriptor for the shuffle primitive.
+ *
+ * A single memory descriptor (data_md_) describes both the input and the
+ * output tensor: on forward it serves src/dst, on backward diff_src and
+ * diff_dst. */
+struct shuffle_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::shuffle;
+
+    typedef shuffle_pd_t base_class;
+    typedef shuffle_pd_t hint_class;
+
+    shuffle_pd_t(engine_t *engine,
+            const shuffle_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const shuffle_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+        , data_md_(desc_.data_desc)
+    {}
+
+    const shuffle_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    /* answers query::shuffle_d; everything else goes to the base class */
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::shuffle_d:
+            *(const shuffle_desc_t**)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* src/dst on the forward pass, diff_dst/diff_src on backward */
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (is_fwd()) {
+            if (arg == MKLDNN_ARG_SRC)
+                return arg_usage_t::input;
+
+            if (arg == MKLDNN_ARG_DST)
+                return arg_usage_t::output;
+        } else {
+            if (arg == MKLDNN_ARG_DIFF_DST)
+                return arg_usage_t::input;
+
+            if (arg == MKLDNN_ARG_DIFF_SRC)
+                return arg_usage_t::output;
+        }
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    /* forward accessors are valid only on a forward pd, diff accessors
+     * only on a backward pd; all four alias the same data_md_ */
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 && is_fwd() ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 && is_fwd() ? &data_md_ : nullptr; }
+
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 && !is_fwd() ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 && !is_fwd() ? &data_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 1; }
+    virtual int n_outputs() const override { return 1; }
+
+    /* shuffle aux functions */
+
+    /* dims[0] = minibatch, dims[1] = channels; trailing dims are treated
+     * as spatial and default to 1 when the rank is too small */
+    dim_t MB() const { return data_md()->dims[0]; }
+    dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; }
+    dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; }
+    dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; }
+    dim_t W() const { return ndims() >= 3 ? data_md()->dims[ndims() - 1] : 1; }
+
+    int ndims() const { return data_md()->ndims; }
+
+    int axis() const { return desc_.axis; }
+    dim_t group_size() const { return desc_.group_size; }
+    dim_t axis_size() const { return data_md()->dims[axis()]; }
+
+    /* true for forward_training / forward_inference */
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+    const memory_desc_t *data_md() const { return &data_md_; }
+
+protected:
+    shuffle_desc_t desc_;
+    const shuffle_pd_t *hint_fwd_pd_;
+    memory_desc_t data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp
new file mode 100644
index 0000000000..82848e3d1f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp
@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "memory_desc_wrapper.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+/* Common initializer for softmax descriptors (forward and backward).
+ *
+ * @param softmax_desc  out: descriptor to fill
+ * @param prop_kind     propagation kind (forward_* or backward_data)
+ * @param data_desc     memory descriptor of the data
+ * @param diff_desc     memory descriptor of the diff data; required for
+ *                      backward_data, may be NULL otherwise
+ * @param softmax_axis  axis along which softmax is computed
+ */
+status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind,
+        const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) {
+    const bool is_bwd = prop_kind == backward_data;
+    bool args_ok = true
+        && !any_null(softmax_desc, data_desc)
+        /* diff_desc is dereferenced below on the backward path, so it must
+         * be non-NULL there (the forward path passes NULL legitimately) */
+        && IMPLICATION(is_bwd, diff_desc != nullptr)
+        && 0 <= softmax_axis
+        && softmax_axis < data_desc->ndims;
+    if (!args_ok) return invalid_arguments;
+
+    /* build into a local so the output is only written once fully valid */
+    auto sd = softmax_desc_t();
+    sd.primitive_kind = primitive_kind::softmax;
+    sd.prop_kind = prop_kind;
+
+    sd.data_desc = *data_desc;
+    sd.diff_desc = is_bwd ? *diff_desc : zero_md();
+    sd.softmax_axis = softmax_axis;
+
+    *softmax_desc = sd;
+    return success;
+}
+}
+
+/* Public C API: initializes a forward softmax descriptor.
+ * Only forward_inference / forward_training are accepted; the diff
+ * descriptor slot is unused on the forward path (passed as nullptr). */
+status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc,
+        prop_kind_t prop_kind, const memory_desc_t *data_desc,
+        int softmax_axis) {
+    if (!one_of(prop_kind, forward_inference, forward_training))
+        return invalid_arguments;
+    return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis);
+}
+
+/* Public C API: initializes a backward (data) softmax descriptor.
+ * Note: the C API takes diff_desc first, while the internal helper takes
+ * data_desc first — the argument order is deliberately swapped below. */
+status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc,
+        const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc,
+        int softmax_axis) {
+    return softmax_desc_init(softmax_desc, prop_kind::backward_data,
+            data_desc, diff_desc, softmax_axis);
+}
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp
new file mode 100644
index 0000000000..8a16ce901c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp
@@ -0,0 +1,161 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SOFTMAX_PD_HPP
+#define SOFTMAX_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct softmax_fwd_pd_t;
+
+/* Base primitive descriptor for softmax (forward and backward).
+ * Holds the operation descriptor, an optional hint to the matching forward
+ * pd, and a copy of the data memory descriptor. */
+struct softmax_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::softmax;
+
+    softmax_pd_t(engine_t *engine,
+            const softmax_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const softmax_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+        , data_md_(desc_.data_desc)
+    {}
+
+    const softmax_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    /* answers query::softmax_d; everything else goes to the base class */
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::softmax_d:
+            *(const softmax_desc_t**)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* common softmax aux functions */
+
+    /* dims[0] = minibatch, dims[1] = channels; trailing dims are treated
+     * as spatial and default to 1 when the rank is too small.
+     * NOTE(review): unlike D/H/W, C() reads dims[1] without an
+     * ndims >= 2 guard — confirm softmax descriptors always have
+     * rank >= 2. */
+    dim_t MB() const { return data_desc().dims[0]; }
+    dim_t C() const { return data_desc().dims[1]; }
+    dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+    dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+    dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+    int ndims() const { return data_desc().ndims; }
+
+    /* true for forward_training / forward_inference */
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+protected:
+    softmax_desc_t desc_;
+    const softmax_fwd_pd_t *hint_fwd_pd_;
+
+    memory_desc_t data_md_;
+
+private:
+    const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+/* Forward softmax pd: one input (src), one output (dst), plus an optional
+ * workspace output when the implementation publishes one. */
+struct softmax_fwd_pd_t: public softmax_pd_t {
+    typedef softmax_fwd_pd_t base_class;
+    typedef softmax_fwd_pd_t hint_class;
+
+    softmax_fwd_pd_t(engine_t *engine,
+            const softmax_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const softmax_fwd_pd_t *hint_fwd_pd)
+        : softmax_pd_t(engine, adesc, attr, hint_fwd_pd)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg == MKLDNN_ARG_SRC)
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        /* workspace is an output here (produced on fwd, consumed on bwd) */
+        if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    /* src and dst share the same memory descriptor */
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 1; }
+    virtual int n_outputs() const override
+    { return 1 + (workspace_md() != nullptr); }
+};
+
+/* Backward softmax pd. Inputs are the forward DST (not src) and diff_dst;
+ * the output is diff_src. diff_src and diff_dst share one descriptor. */
+struct softmax_bwd_pd_t: public softmax_pd_t {
+    typedef softmax_bwd_pd_t base_class;
+    typedef softmax_fwd_pd_t hint_class;
+
+    softmax_bwd_pd_t(engine_t *engine,
+            const softmax_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const softmax_fwd_pd_t *hint_fwd_pd)
+        : softmax_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_data_md_(desc_.diff_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        /* backward consumes the forward result (DST) and the diff wrt dst */
+        if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_SRC)
+            return arg_usage_t::output;
+
+        /* workspace is an input here (it was produced on the forward pass) */
+        if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+            return arg_usage_t::input;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_data_md_ : nullptr; }
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 ? &diff_data_md_ : nullptr; }
+
+    virtual int n_inputs() const override
+    { return 2 + (workspace_md() != nullptr); }
+    virtual int n_outputs() const override { return 1; }
+
+protected:
+    memory_desc_t diff_data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp
new file mode 100644
index 0000000000..00af8935c0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp
@@ -0,0 +1,46 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "stream.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+/* API */
+
+/* Creates a stream on the given engine. Both out-pointer and engine are
+ * required, and only the default flag set is supported. */
+status_t mkldnn_stream_create(stream_t **stream, engine_t *engine,
+        unsigned flags) {
+    if (utils::any_null(stream, engine))
+        return invalid_arguments;
+    if (flags != stream_flags::default_flags)
+        return invalid_arguments;
+
+    return safe_ptr_assign<stream_t>(*stream, new stream_t(engine, flags));
+}
+
+/* Destroys a stream created by mkldnn_stream_create.
+ * A NULL stream is a harmless no-op (`delete nullptr` is well-defined). */
+status_t mkldnn_stream_destroy(stream_t *stream) {
+    delete stream;
+    return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp
new file mode 100644
index 0000000000..f010e5f6ed
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp
@@ -0,0 +1,44 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef STREAM_HPP
+#define STREAM_HPP
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+
+/* Execution stream: a thin handle pairing an engine with stream flags.
+ * The engine pointer is not owned — the destructor releases nothing. */
+struct mkldnn_stream: public mkldnn::impl::c_compatible {
+    mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags)
+        : engine_(engine), flags_(flags) {}
+    virtual ~mkldnn_stream() {}
+
+    /** returns stream's engine */
+    mkldnn::impl::engine_t *engine() const { return engine_; }
+
+    /** returns stream's kind */
+    unsigned flags() const { return flags_; }
+
+protected:
+    mkldnn::impl::engine_t *engine_; /* non-owning */
+    unsigned flags_;
+};
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp
new file mode 100644
index 0000000000..365663c0f8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp
@@ -0,0 +1,79 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "sum_pd.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+/* Public C API: creates a sum primitive descriptor.
+ *
+ * Validates that all n sources agree in rank, dims and data type, builds a
+ * default destination descriptor (format_kind::any) when none is supplied,
+ * then walks the engine's sum implementation list and keeps the first
+ * implementation that accepts the configuration. */
+status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd,
+        const memory_desc_t *dst_md, int n, const float *scales,
+        const memory_desc_t *src_mds, const primitive_attr_t *attr,
+        engine_t *engine) {
+    bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0;
+    if (!args_ok) return invalid_arguments;
+
+    /* a missing attr means "default attributes" */
+    const primitive_attr_t dummy_attr;
+    if (attr == NULL)
+        attr = &dummy_attr;
+
+    const int ndims = src_mds[0].ndims;
+    const dims_t &dims = src_mds[0].dims;
+    const data_type_t dt = src_mds[0].data_type;
+
+    /* every source must match source 0 in rank, shape and data type */
+    for (int i = 1; i < n; ++i) {
+        if (src_mds[i].ndims != ndims) return invalid_arguments;
+        for (int d = 0; d < ndims; ++d) {
+            if (src_mds[i].dims[d] != dims[d])
+                return invalid_arguments;
+        }
+        if (src_mds[i].data_type != dt) return invalid_arguments;
+    }
+
+    memory_desc_t dummy_dst_md;
+    if (dst_md) {
+        /* user-provided dst must match the sources' rank and shape */
+        if (dst_md->ndims != ndims) return invalid_arguments;
+        for (int d = 0; d < ndims; ++d) {
+            if (dst_md->dims[d] != dims[d])
+                return invalid_arguments;
+        }
+    } else {
+        /* no dst given: clone src[0] and let the implementation choose
+         * the memory format (format_kind::any) */
+        dummy_dst_md = src_mds[0];
+        dummy_dst_md.format_kind = format_kind::any;
+        dst_md = &dummy_dst_md;
+    }
+
+    auto s_pd = reinterpret_cast<sum_pd_t **>(sum_pd);
+
+    /* first implementation that reports success wins */
+    for (auto s = engine->get_sum_implementation_list(); *s; ++s) {
+        if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) {
+            (*s_pd)->init_info();
+            (*s_pd)->init_scratchpad_md();
+            return success;
+        }
+    }
+    return unimplemented;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp
new file mode 100644
index 0000000000..80254667df
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp
@@ -0,0 +1,143 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SUM_PD_HPP
+#define SUM_PD_HPP
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/* Primitive descriptor for sum (scaled element-wise sum of n sources).
+ * Stores private copies of the n source descriptors, the per-source
+ * scales, and the destination descriptor. */
+struct sum_pd_t: public primitive_desc_t {
+    sum_pd_t(engine_t *engine, const primitive_attr_t *attr,
+            const memory_desc_t *dst_md, int n, const float *scales,
+            const memory_desc_t *src_mds)
+        : primitive_desc_t(engine, attr, primitive_kind::sum)
+        , n_(n), dst_md_(*dst_md)
+    {
+        /* deep-copy the caller's arrays — they may not outlive this pd */
+        scales_.reserve(n_);
+        for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]);
+        src_mds_.reserve(n_);
+        for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
+    }
+
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    /* sources are the contiguous MULTIPLE_SRC range [base, base + n) */
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg >= MKLDNN_ARG_MULTIPLE_SRC
+                && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index < n_inputs() ? &src_mds_[index] : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &dst_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return n_; }
+    virtual int n_outputs() const override { return 1; }
+
+    /* raw pointer into scales_; valid while this pd lives (creation
+     * guarantees n_ > 0, so the vector is never empty) */
+    const float *scales() const { return &scales_[0]; }
+
+protected:
+    int n_;                              /* number of sources */
+    nstl::vector<float> scales_;         /* one scale per source */
+    memory_desc_t dst_md_;
+    nstl::vector<memory_desc_t> src_mds_;
+
+protected:
+    /* inits dst_md_ in simple cases. The call may fail. */
+    status_t init() {
+        /* only plain blocking sources without extra buffers are handled */
+        for (int i = 0; i < n_; ++i) {
+            const memory_desc_wrapper src_d(&src_mds_[i]);
+            if (!src_d.is_blocking_desc() || src_d.is_additional_buffer())
+                return status::unimplemented;
+        }
+        bool ok = true
+            && set_default_params() == status::success
+            && attr()->has_default_values();
+        return ok ? status::success : status::unimplemented;
+    }
+
+    /* resolves dst_md_ when the caller left its format as `any` */
+    status_t set_default_params() {
+        if (dst_md_.format_kind != format_kind::any)
+            return status::success;
+
+        /* The stupidest ever heuristics (but not the same as we had before):
+         *  - Pick the first non-plain format;
+         *  - If all formats are plain, pick the format of the first input
+         */
+        for (int i = 0; i < n_; ++i) {
+            const memory_desc_wrapper src_d(src_mds_[i]);
+            if (!src_d.is_plain() && src_d.is_blocking_desc()) {
+                return memory_desc_init_by_blocking_desc(dst_md_,
+                        src_d.blocking_desc());
+            }
+        }
+
+        if (src_mds_[0].format_kind != format_kind::blocked)
+            return status::unimplemented;
+
+        dst_md_ = src_mds_[0];
+
+        return status::success;
+    }
+};
+
+/* Boilerplate injected into concrete sum primitive descriptors:
+ *  - create():           builds and init()s a pd_t; on success transfers
+ *                        ownership to *sum_pd
+ *  - create_primitive(): instantiates the primitive (the __VA_ARGS__
+ *                        class), timing creation for verbose level >= 2
+ *  - clone()/name():     standard pd plumbing
+ * (comments must stay outside the macro: `//` would swallow the trailing
+ * line-continuation backslashes) */
+#define DECLARE_SUM_PD_t(impl_name, ...) \
+    static status_t create(sum_pd_t **sum_pd, \
+            engine_t *engine, const primitive_attr_t *attr, \
+            const memory_desc_t *dst_md, int n, const float *scales, \
+            const memory_desc_t *src_mds) { \
+        using namespace status; \
+        auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \
+        if (_pd == nullptr) return out_of_memory; \
+        if (_pd->init() != success) { delete _pd; return unimplemented; } \
+        return safe_ptr_assign<sum_pd_t>(*sum_pd, _pd); \
+    } \
+    virtual status_t create_primitive(primitive_t **p) const override { \
+        double ms = get_msec(); \
+        auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
+        ms = get_msec() - ms; \
+        if (mkldnn_verbose()->level >= 2) { \
+            printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
+            fflush(0); \
+        } \
+        return ret; \
+    } \
+    virtual pd_t *clone() const override { return new pd_t(*this); } \
+    virtual const char *name() const override { return impl_name; } \
+
+/* public spelling; simply forwards to DECLARE_SUM_PD_t */
+#define DECLARE_SUM_PD_T(impl_name, ...) \
+    DECLARE_SUM_PD_t(impl_name, __VA_ARGS__)
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp
new file mode 100644
index 0000000000..a408f45980
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp
@@ -0,0 +1,200 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef TAG_TRAITS_HPP
+#define TAG_TRAITS_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+enum class block_dim_t {
+ _,
+ _A, _B,
+ _AB, _BC,
+};
+
+enum class inner_blk_t {
+ _,
+ _4a, _4b,
+ _8a, _8b,
+ _16a, _16b,
+
+ _4b4a, _4b4c, _4c4b,
+ _8a8b, _8b8a, _8b8c, _8c8b,
+ _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b,
+
+ _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c,
+};
+
+/** returns the offset within the block for weights blocked over oc and ic */
+template <inner_blk_t f>
+constexpr int AB_or_BC_blk_off(int x0, int x1) {
+ using ib = inner_blk_t;
+ static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b,
+ ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b,
+ ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c,
+ ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b,
+ ib::_4c16b4c, ib::_8c16b2c),
+ "unexpected inner_blk format");
+ return false ? 0
+ : (f == ib::_4b4c) ? 4 * x0 + x1
+ : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0
+ : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1
+ : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0
+ : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1
+ : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0
+ : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1
+ : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2
+ : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4
+ : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
+ : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4
+ : INT_MIN;
+}
+
+template <inner_blk_t b> struct inner_blk_traits {
+ using ib = inner_blk_t;
+};
+
+template <format_tag_t> struct tag_traits {
+ // block_dim_t block_dims;
+ // inner_blk_t inner_blks;
+ // int ndims;
+};
+
+#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \
+template <> struct tag_traits<format_tag::_tag> { \
+ static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \
+ static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \
+ static constexpr int ndims = _ndims; \
+}
+
+DECL_TRAITS(a, _, _, 1);
+DECL_TRAITS(ab, _, _, 2);
+DECL_TRAITS(abc, _, _, 3);
+DECL_TRAITS(abcd, _, _, 4);
+DECL_TRAITS(abcde, _, _, 5);
+DECL_TRAITS(abcdef, _, _, 6);
+DECL_TRAITS(abdec, _, _, 5);
+DECL_TRAITS(acb, _, _, 3);
+DECL_TRAITS(acbde, _, _, 5);
+DECL_TRAITS(acdb, _, _, 4);
+DECL_TRAITS(acdeb, _, _, 5);
+DECL_TRAITS(ba, _, _, 2);
+DECL_TRAITS(bac, _, _, 3);
+DECL_TRAITS(bacd, _, _, 4);
+DECL_TRAITS(bcda, _, _, 4);
+DECL_TRAITS(cba, _, _, 3);
+DECL_TRAITS(cdba, _, _, 4);
+DECL_TRAITS(cdeba, _, _, 5);
+DECL_TRAITS(decab, _, _, 5);
+
+DECL_TRAITS(Abc4a, _A, _4a, 3);
+DECL_TRAITS(aBc4b, _B, _4b, 3);
+DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3);
+DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3);
+DECL_TRAITS(Abcd4a, _A, _4a, 4);
+DECL_TRAITS(aBcd4b, _B, _4b, 4);
+DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4);
+DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4);
+DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4);
+DECL_TRAITS(Abcde4a, _A, _4a, 5);
+DECL_TRAITS(aBcde4b, _B, _4b, 5);
+DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5);
+DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5);
+DECL_TRAITS(aBcdef4b, _B, _4b, 6);
+DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6);
+DECL_TRAITS(aBdc4b, _B, _4b, 4);
+DECL_TRAITS(aBdec4b, _B, _4b, 5);
+DECL_TRAITS(aBdefc4b, _B, _4b, 6);
+DECL_TRAITS(Acb4a, _A, _4a, 3);
+DECL_TRAITS(Acdb4a, _A, _4a, 4);
+DECL_TRAITS(Acdeb4a, _A, _4a, 5);
+
+DECL_TRAITS(Abc16a, _A, _16a, 3);
+DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3);
+DECL_TRAITS(aBc16b, _B, _16b, 3);
+DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3);
+DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3);
+DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3);
+DECL_TRAITS(aBc8b, _B, _8b, 3);
+DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3);
+DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3);
+DECL_TRAITS(Abcd16a, _A, _16a, 4);
+DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4);
+DECL_TRAITS(aBcd16b, _B, _16b, 4);
+DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4);
+DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4);
+DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4);
+DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4);
+DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4);
+DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4);
+DECL_TRAITS(aBcd8b, _B, _8b, 4);
+DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4);
+DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4);
+DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4);
+DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4);
+DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4);
+DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4);
+DECL_TRAITS(Abcde16a, _A, _16a, 5);
+DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5);
+DECL_TRAITS(aBcde16b, _B, _16b, 5);
+DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5);
+DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5);
+DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5);
+DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5);
+DECL_TRAITS(Abcde8a, _A, _8a, 5);
+DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5);
+DECL_TRAITS(aBcde8b, _B, _8b, 5);
+DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5);
+DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5);
+DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5);
+DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 5);
+DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5);
+DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5);
+DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5);
+DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5);
+DECL_TRAITS(aBcdef16b, _B, _16b, 6);
+DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6);
+DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6);
+DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6);
+DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6);
+DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6);
+DECL_TRAITS(aBdc16b, _B, _16b, 4);
+DECL_TRAITS(aBdc8b, _B, _8b, 4);
+DECL_TRAITS(aBdec16b, _B, _16b, 5);
+DECL_TRAITS(aBdec8b, _B, _8b, 5);
+DECL_TRAITS(aBdefc16b, _B, _16b, 6);
+DECL_TRAITS(aBdefc8b, _B, _8b, 6);
+DECL_TRAITS(Acb16a, _A, _16a, 3);
+DECL_TRAITS(Acb8a, _A, _8a, 3);
+DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4);
+DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5);
+DECL_TRAITS(Acdb16a, _A, _16a, 4);
+DECL_TRAITS(Acdb8a, _A, _8a, 4);
+DECL_TRAITS(Acdeb16a, _A, _16a, 5);
+DECL_TRAITS(Acdeb8a, _A, _8a, 5);
+DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3);
+DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4);
+
+} // namespace impl
+} // namespace mkldnn
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp
new file mode 100644
index 0000000000..4f06368738
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp
@@ -0,0 +1,348 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef TYPE_HELPERS_HPP
+#define TYPE_HELPERS_HPP
+
+#include <assert.h>
+#include <math.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "mkldnn_traits.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+#include "math_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+template <typename T>
+status_t safe_ptr_assign(T * &lhs, T* rhs) {
+ if (rhs == nullptr) return status::out_of_memory;
+ lhs = rhs;
+ return status::success;
+}
+
+template <typename T, typename U> struct is_subset
+{ static constexpr bool value = false; };
+template <typename T> struct is_subset<T, T>
+{ static constexpr bool value = true; };
+template <typename T> struct is_subset<T,
+ typename utils::enable_if<nstl::is_integral<T>::value, float>::type>
+{ static constexpr bool value = true; };
+#define ISSPEC(t1, t2) template <> \
+ struct is_subset<t1, t2> { static constexpr bool value = true; }
+ISSPEC(int16_t, int32_t);
+ISSPEC(int8_t, int32_t);
+ISSPEC(uint8_t, int32_t);
+ISSPEC(int8_t, int16_t);
+ISSPEC(uint8_t, int16_t);
+#undef ISSPEC
+
+inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs);
+
+namespace types {
+
+inline size_t data_type_size(data_type_t data_type) {
+ using namespace data_type;
+ switch (data_type) {
+ case f32: return sizeof(prec_traits<f32>::type);
+ case s32: return sizeof(prec_traits<s32>::type);
+ case s8: return sizeof(prec_traits<s8>::type);
+ case u8: return sizeof(prec_traits<u8>::type);
+ case data_type::undef:
+ default: assert(!"unknown data_type");
+ }
+ return 0; /* not supposed to be reachable */
+}
+
+inline format_kind_t format_tag_to_kind(format_tag_t tag) {
+ switch (tag) {
+ case format_tag::undef: return format_kind::undef;
+ case format_tag::any: return format_kind::any;
+ case format_tag::last: return format_kind::undef;
+ default: return format_kind::blocked;
+ }
+
+ assert(!"unreachable");
+ return format_kind::undef;
+}
+
+inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs,
+ const memory_extra_desc_t &rhs) {
+ return true
+ && lhs.flags == rhs.flags
+ && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8,
+ lhs.compensation_mask == rhs.compensation_mask)
+ && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust,
+ lhs.scale_adjust == rhs.scale_adjust);
+}
+
+inline bool blocking_desc_is_equal(const blocking_desc_t &lhs,
+ const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) {
+ using mkldnn::impl::utils::array_cmp;
+ return true
+ && lhs.inner_nblks == rhs.inner_nblks
+ && array_cmp(lhs.strides, rhs.strides, ndims)
+ && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks)
+ && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks);
+}
+
+inline bool wino_desc_is_equal(const wino_desc_t &lhs,
+ const wino_desc_t &rhs) {
+ return lhs.wino_format == rhs.wino_format
+ && lhs.alpha == rhs.alpha
+ && lhs.ic == rhs.ic
+ && lhs.oc == rhs.oc
+ && lhs.ic_block == rhs.ic_block
+ && lhs.oc_block == rhs.oc_block
+ && lhs.ic2_block == rhs.ic2_block
+ && lhs.oc2_block == rhs.oc2_block
+ && lhs.r == rhs.r;
+}
+
+inline bool rnn_packed_desc_is_equal(
+ const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) {
+ bool ok = true
+ && lhs.format == rhs.format
+ && lhs.n_parts == rhs.n_parts
+ && lhs.offset_compensation == rhs.offset_compensation
+ && lhs.size == rhs.size
+ && lhs.n == rhs.n;
+ if (!ok)
+ return false;
+
+ for (int i = 0; i < rhs.n_parts; i++)
+ ok = ok && lhs.parts[i] == rhs.parts[i];
+ for (int i = 0; i < rhs.n_parts; i++)
+ ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i];
+ return ok;
+}
+
+inline memory_desc_t zero_md() {
+ auto zero = memory_desc_t();
+ return zero;
+}
+
+inline bool is_zero_md(const memory_desc_t *md) {
+ return md == nullptr || *md == zero_md();
+}
+
+inline data_type_t default_accum_data_type(data_type_t src_dt,
+ data_type_t dst_dt) {
+ using namespace utils;
+ using namespace data_type;
+
+ if (one_of(f32, src_dt, dst_dt)) return f32;
+ if (one_of(s32, src_dt, dst_dt)) return s32;
+
+ if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;
+
+ assert(!"unimplemented use-case: no default parameters available");
+ return dst_dt;
+}
+
+inline data_type_t default_accum_data_type(data_type_t src_dt,
+ data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) {
+ using namespace utils;
+ using namespace data_type;
+ using namespace prop_kind;
+
+ /* prop_kind doesn't matter */
+ if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32;
+
+ if (one_of(prop_kind, forward_training, forward_inference)) {
+ if ((src_dt == u8 || src_dt == s8)
+ && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8))
+ return s32;
+ } else if (prop_kind == backward_data) {
+ if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 &&
+ one_of(dst_dt, s8, u8))
+ return s32;
+ }
+
+ assert(!"unimplemented use-case: no default parameters available");
+ return dst_dt;
+}
+
+}
+
+inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
+ using namespace mkldnn::impl::utils;
+ bool base_equal = true
+ && lhs.ndims == rhs.ndims
+ && array_cmp(lhs.dims, rhs.dims, lhs.ndims)
+ && lhs.data_type == rhs.data_type
+ && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims)
+ && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims)
+ && lhs.offset0 == rhs.offset0
+ && lhs.format_kind == rhs.format_kind;
+ if (!base_equal) return false;
+ if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false;
+ if (lhs.format_kind == format_kind::blocked)
+ return types::blocking_desc_is_equal(lhs.format_desc.blocking,
+ rhs.format_desc.blocking, lhs.ndims);
+ else if (lhs.format_kind == format_kind::wino)
+ return types::wino_desc_is_equal(lhs.format_desc.wino_desc,
+ rhs.format_desc.wino_desc);
+ else if (lhs.format_kind == format_kind::rnn_packed)
+ return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc,
+ rhs.format_desc.rnn_packed_desc);
+ return true;
+}
+
+inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) {
+ return !operator==(lhs, rhs);
+}
+
+inline status_t memory_desc_init_by_strides(memory_desc_t &md,
+ const dims_t strides) {
+ return mkldnn_memory_desc_init_by_strides(
+ &md, md.ndims, md.dims, md.data_type, strides);
+}
+
+inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag,
+ const dims_t strides = nullptr) {
+ status_t status = mkldnn_memory_desc_init_by_tag(
+ &md, md.ndims, md.dims, md.data_type, tag);
+ if (status != status::success || strides == nullptr)
+ return status;
+
+ /* TODO: add consistency check */
+
+ for (int d = 0; d < md.ndims; ++d)
+ md.format_desc.blocking.strides[d] = strides[d];
+
+ return status::success;
+}
+
+/** inits memory descriptor based on logical dimensions kept in @p md, and the
+ * blocking structure @p blk.
+ *
+ * @note blk.strides represent the order only (from smaller to bigger)
+ *
+ * TODO: move md related functions to one single place
+ */
+inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md,
+ const blocking_desc_t &blk) {
+ dims_t blocks = {0};
+ utils::array_set(blocks, 1, md.ndims);
+ dim_t block_size = 1;
+ for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
+ blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk];
+ block_size *= blk.inner_blks[iblk];
+ }
+
+ for (int d = 0; d < md.ndims; ++d) {
+ md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
+ md.padded_offsets[d] = 0;
+ }
+ md.offset0 = 0;
+
+ md.format_kind = format_kind::blocked;
+ auto &mblk = md.format_desc.blocking;
+ mblk = blk;
+
+ const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy
+ utils::array_copy(mblk.strides, blk.strides, ndims);
+
+ int perm[MKLDNN_MAX_NDIMS];
+ for (int d = 0; d < ndims; ++d) perm[d] = d;
+
+ utils::simultaneous_sort(mblk.strides, perm, ndims,
+ [](stride_t a, stride_t b) { return b - a; });
+
+ dim_t stride = block_size;
+ for (int _d = ndims - 1; _d >= 0; --_d) {
+ const int d = perm[_d];
+ md.format_desc.blocking.strides[d] = stride;
+ stride *= md.padded_dims[d] / blocks[d];
+ }
+
+ md.extra = utils::zero<memory_extra_desc_t>();
+
+ return status::success;
+}
+
+/** returns true if memory desc @p md corresponds to the given format tag and
+ * strides.
+ * If strides are not passed (or passed as nullptr) the dense structure is
+ * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns).
+ * Strides might contain `0` value, indicating the stride must match the one
+ * that mkldnn_memory_desc_init_by_tag() returns.
+ * Strides might contain `-1` values, that would be ignored during the
+ * comparison. For instance, this can be used if a stride along minibatch
+ * doesn't matter. */
+inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag,
+ const dims_t strides = nullptr) {
+ if (md.format_kind != types::format_tag_to_kind(tag))
+ return false;
+
+ memory_desc_t md_gold;
+ status_t status = mkldnn_memory_desc_init_by_tag(
+ &md_gold, md.ndims, md.dims, md.data_type, tag);
+ if (status != status::success) return false;
+
+ if (md.format_kind != format_kind::blocked)
+ return false; // unimplemented yet
+
+ const auto &blk = md.format_desc.blocking;
+ const auto &blk_gold = md_gold.format_desc.blocking;
+
+ using utils::array_cmp;
+ bool same_blocks = true
+ && blk.inner_nblks == blk_gold.inner_nblks
+ && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks)
+ && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks);
+
+ if (!same_blocks)
+ return false;
+
+ if (strides == nullptr)
+ return array_cmp(blk.strides, blk_gold.strides, md.ndims);
+
+ for (int d = 0; d < md.ndims; ++d) {
+ dim_t stride = strides[d];
+ if (stride == -1) continue;
+ if (stride == 0) stride = blk_gold.strides[d];
+ if (blk.strides[d] != stride) return false;
+ }
+
+ return true;
+}
+
+/** returns matching tag (or undef if match is not found)
+ * XXX: This is a workaround that eventually should go away! */
+template <typename... Tags>
+format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md,
+ Tags ...tags) {
+ for (const auto tag: {tags...}) {
+ if (memory_desc_matches_tag(md, tag))
+ return tag;
+ }
+ return format_tag::undef;
+}
+
+}
+}
+
+#include "memory_desc_wrapper.hpp"
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp
new file mode 100644
index 0000000000..d23f4682dc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp
@@ -0,0 +1,135 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <string.h>
+#ifdef _WIN32
+#include <malloc.h>
+#include <windows.h>
+#endif
+#include <limits.h>
+#include <stdlib.h>
+#include <stdio.h>
+
+#include "mkldnn.h"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+int getenv(const char *name, char *buffer, int buffer_size) {
+ if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0))
+ return INT_MIN;
+
+ int result = 0;
+ int term_zero_idx = 0;
+ size_t value_length = 0;
+
+#ifdef _WIN32
+ value_length = GetEnvironmentVariable(name, buffer, buffer_size);
+#else
+ const char *value = ::getenv(name);
+ value_length = value == NULL ? 0 : strlen(value);
+#endif
+
+ if (value_length > INT_MAX)
+ result = INT_MIN;
+ else {
+ int int_value_length = (int)value_length;
+ if (int_value_length >= buffer_size) {
+ result = -int_value_length;
+ } else {
+ term_zero_idx = int_value_length;
+ result = int_value_length;
+#ifndef _WIN32
+ strncpy(buffer, value, value_length);
+#endif
+ }
+ }
+
+ if (buffer != NULL)
+ buffer[term_zero_idx] = '\0';
+ return result;
+}
+
+int getenv_int(const char *name, int default_value)
+{
+ int value = default_value;
+ // # of digits in the longest 32-bit signed int + sign + terminating null
+ const int len = 12;
+ char value_str[len];
+ if (getenv(name, value_str, len) > 0)
+ value = atoi(value_str);
+ return value;
+}
+
+FILE *fopen(const char *filename, const char *mode) {
+#ifdef _WIN32
+ FILE *fp = NULL;
+ return ::fopen_s(&fp, filename, mode) ? NULL : fp;
+#else
+ return ::fopen(filename, mode);
+#endif
+}
+
+void *malloc(size_t size, int alignment) {
+ void *ptr;
+
+#ifdef _WIN32
+ ptr = _aligned_malloc(size, alignment);
+ int rc = ptr ? 0 : -1;
+#else
+ int rc = ::posix_memalign(&ptr, alignment, size);
+#endif
+
+ return (rc == 0) ? ptr : 0;
+}
+
+void free(void *p) {
+#ifdef _WIN32
+ _aligned_free(p);
+#else
+ ::free(p);
+#endif
+}
+
+// Atomic operations
+int32_t fetch_and_add(int32_t *dst, int32_t val) {
+#ifdef _WIN32
+ return InterlockedExchangeAdd(reinterpret_cast<long*>(dst), val);
+#else
+ return __sync_fetch_and_add(dst, val);
+#endif
+}
+
+static int jit_dump_flag = 0;
+static bool jit_dump_flag_initialized = false;
+bool jit_dump_enabled() {
+ if (!jit_dump_flag_initialized) {
+ jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP");
+ jit_dump_flag_initialized = true;
+ }
+ return jit_dump_flag != 0;
+}
+
+}
+}
+
+mkldnn_status_t mkldnn_set_jit_dump(int enabled) {
+ using namespace mkldnn::impl::status;
+ mkldnn::impl::jit_dump_flag = enabled;
+ mkldnn::impl::jit_dump_flag_initialized = true;
+ return success;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp
new file mode 100644
index 0000000000..d5a8ec5139
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp
@@ -0,0 +1,370 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef UTILS_HPP
+#define UTILS_HPP
+
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+#include <stdint.h>
+
+#if defined(__x86_64__) || defined(_M_X64)
+#define MKLDNN_X86_64
+#endif
+
+#define MSAN_ENABLED 0
+#if defined(__has_feature)
+#if __has_feature(memory_sanitizer)
+#undef MSAN_ENABLED
+#define MSAN_ENABLED 1
+#include <sanitizer/msan_interface.h>
+#endif
+#endif
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+// Sanity check for 64 bits
+static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only");
+
+#define CHECK(f) do { \
+ status_t status = f; \
+ if (status != status::success) \
+ return status; \
+} while (0)
+
+#define IMPLICATION(cause, effect) (!(cause) || !!(effect))
+
+namespace utils {
+
+/* a bunch of std:: analogues to be compliant with any msvs version
+ *
+ * Rationale: msvs c++ (and even some c) headers contain special pragma that
+ * injects msvs-version check into object files in order to abi-mismatches
+ * during the static linking. This makes sense if e.g. std:: objects are passed
+ * through between application and library, which is not the case for mkl-dnn
+ * (since there is no any c++-rt dependent stuff, ideally...). */
+
+/* SFINAE helper -- analogue to std::enable_if */
+template<bool expr, class T = void> struct enable_if {};
+template<class T> struct enable_if<true, T> { typedef T type; };
+
+/* analogue std::conditional */
+template <bool, typename, typename> struct conditional {};
+template <typename T, typename F> struct conditional<true, T, F>
+{ typedef T type; };
+template <typename T, typename F> struct conditional<false, T, F>
+{ typedef F type; };
+
+template <bool, typename, bool, typename, typename> struct conditional3 {};
+template <typename T, typename FT, typename FF>
+struct conditional3<true, T, false, FT, FF> { typedef T type; };
+template <typename T, typename FT, typename FF>
+struct conditional3<false, T, true, FT, FF> { typedef FT type; };
+template <typename T, typename FT, typename FF>
+struct conditional3<false, T, false, FT, FF> { typedef FF type; };
+
+template <bool, typename U, U, U> struct conditional_v {};
+template <typename U, U t, U f> struct conditional_v<true, U, t, f>
+{ static constexpr U value = t; };
+template <typename U, U t, U f> struct conditional_v<false, U, t, f>
+{ static constexpr U value = f; };
+
+template <typename T> struct remove_reference { typedef T type; };
+template <typename T> struct remove_reference<T&> { typedef T type; };
+template <typename T> struct remove_reference<T&&> { typedef T type; };
+
+template <typename T>
+inline T&& forward(typename utils::remove_reference<T>::type &t)
+{ return static_cast<T&&>(t); }
+template <typename T>
+inline T&& forward(typename utils::remove_reference<T>::type &&t)
+{ return static_cast<T&&>(t); }
+
+template <typename T>
+inline typename remove_reference<T>::type zero()
+{ auto zero = typename remove_reference<T>::type(); return zero; }
+
+template <typename T, typename P>
+inline bool everyone_is(T val, P item) { return val == item; }
+template <typename T, typename P, typename... Args>
+inline bool everyone_is(T val, P item, Args... item_others) {
+ return val == item && everyone_is(val, item_others...);
+}
+
+template <typename T, typename P>
+constexpr bool one_of(T val, P item) { return val == item; }
+template <typename T, typename P, typename... Args>
+constexpr bool one_of(T val, P item, Args... item_others) {
+ return val == item || one_of(val, item_others...);
+}
+
+template <typename... Args>
+inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); }
+
+template<typename T>
+inline void array_copy(T *dst, const T *src, size_t size) {
+ for (size_t i = 0; i < size; ++i) dst[i] = src[i];
+}
+template<typename T>
+inline bool array_cmp(const T *a1, const T *a2, size_t size) {
+ for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false;
+ return true;
+}
+template<typename T, typename U>
+inline void array_set(T *arr, const U& val, size_t size) {
+ for (size_t i = 0; i < size; ++i) arr[i] = static_cast<T>(val);
+}
+
+namespace product_impl {
+template<size_t> struct int2type{};
+
+template <typename T>
+constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; }
+
+template <typename T, size_t num>
+inline T product_impl(const T *arr, int2type<num>) {
+ return arr[0]*product_impl(arr+1, int2type<num-1>()); }
+}
+
+template <size_t num, typename T>
+inline T array_product(const T *arr) {
+ return product_impl::product_impl(arr, product_impl::int2type<num-1>());
+}
+
+template<typename T, typename R = T>
+inline R array_product(const T *arr, size_t size) {
+ R prod = 1;
+ for (size_t i = 0; i < size; ++i) prod *= arr[i];
+ return prod;
+}
+
+/** sorts an array of values using @p comparator. While sorting the array
+ * of value, the function permutes an array of @p keys accordingly.
+ *
+ * @note The arrays of @p keys can be omitted. In this case the function
+ * sorts the array of @vals only.
+ */
+template <typename T, typename U, typename F>
+inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) {
+ if (size == 0) return;
+
+ for (size_t i = 0; i < size - 1; ++i) {
+ bool swapped = false;
+
+ for (size_t j = 0; j < size - i - 1; j++) {
+ if (comparator(vals[j], vals[j + 1]) > 0) {
+ nstl::swap(vals[j], vals[j + 1]);
+ if (keys) nstl::swap(keys[j], keys[j + 1]);
+ swapped = true;
+ }
+ }
+
+ if (swapped == false) break;
+ }
+}
+
+template <typename T, typename U>
+inline typename remove_reference<T>::type div_up(const T a, const U b) {
+ assert(b);
+ return (a + b - 1) / b;
+}
+
+template <typename T, typename U>
+inline typename remove_reference<T>::type rnd_up(const T a, const U b) {
+ return div_up(a, b) * b;
+}
+
+template <typename T, typename U>
+inline typename remove_reference<T>::type rnd_dn(const T a, const U b) {
+ return (a / b) * b;
+}
+
+template <typename T> T *align_ptr(T *ptr, uintptr_t alignment)
+{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); }
+
+template <typename T, typename U, typename V>
+inline U this_block_size(const T offset, const U max, const V block_size) {
+ assert(offset < max);
+ // TODO (Roma): can't use nstl::max() due to circular dependency... we
+ // need to fix this
+ const T block_boundary = offset + block_size;
+ if (block_boundary > max)
+ return max - offset;
+ else
+ return block_size;
+}
+
+template<typename T>
+inline T nd_iterator_init(T start) { return start; }
+template<typename T, typename U, typename W, typename... Args>
+inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) {
+ start = nd_iterator_init(start, utils::forward<Args>(tuple)...);
+ x = start % X;
+ return start / X;
+}
+
+inline bool nd_iterator_step() { return true; }
+template<typename U, typename W, typename... Args>
+inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) {
+ if (nd_iterator_step(utils::forward<Args>(tuple)...) ) {
+ x = (x + 1) % X;
+ return x == 0;
+ }
+ return false;
+}
+
+template<typename U, typename W, typename Y>
+inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X)
+{
+ U max_jump = end - cur;
+ U dim_jump = X - x;
+ if (dim_jump <= max_jump) {
+ x = 0;
+ cur += dim_jump;
+ return true;
+ } else {
+ cur += max_jump;
+ x += max_jump;
+ return false;
+ }
+}
+template<typename U, typename W, typename Y, typename... Args>
+inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X,
+ Args &&... tuple)
+{
+ if (nd_iterator_jump(cur, end, utils::forward<Args>(tuple)...)) {
+ x = (x + 1) % X;
+ return x == 0;
+ }
+ return false;
+}
+
+template <typename T>
+inline T pick(size_t i, const T &x0) { return x0; }
+template <typename T, typename ...Args>
+inline T pick(size_t i, const T &x0, Args &&... args) {
+ return i == 0 ? x0 : pick(i - 1, utils::forward<Args>(args)...);
+}
+
+template <typename T>
+T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference,
+ const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) {
+ switch (prop_kind) {
+ case prop_kind::forward_inference: return val_fwd_inference;
+ case prop_kind::forward_training: return val_fwd_training;
+ case prop_kind::backward_data: return val_bwd_d;
+ case prop_kind::backward_weights: return val_bwd_w;
+ default: assert(!"unsupported prop_kind");
+ }
+ return T();
+}
+
+template <typename T>
+T pick_by_prop_kind(prop_kind_t prop_kind,
+ const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w)
+{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); }
+
+template <typename Telem, size_t Tdims>
+struct array_offset_calculator {
+ template <typename... Targs>
+ array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... }
+ {
+ _base_ptr = base;
+ }
+ template <typename... Targs>
+ inline Telem &operator()(Targs... Fargs)
+ {
+ return *(_base_ptr + _offset(1, Fargs...));
+ }
+
+private:
+ template <typename... Targs>
+ inline size_t _offset(size_t const dimension, size_t element)
+ {
+ return element;
+ }
+
+ template <typename... Targs>
+ inline size_t _offset(size_t const dimension, size_t theta, size_t element)
+ {
+ return element + (_dims[dimension] * theta);
+ }
+
+ template <typename... Targs>
+ inline size_t _offset(size_t const dimension, size_t theta, size_t element,
+ Targs... Fargs)
+ {
+ size_t t_prime = element + (_dims[dimension] * theta);
+ return _offset(dimension + 1, t_prime, Fargs...);
+ }
+
+ Telem *_base_ptr;
+ const int _dims[Tdims];
+};
+
+}
+
+int32_t fetch_and_add(int32_t *dst, int32_t val);
+inline void yield_thread() {}
+
+// Reads an environment variable 'name' and stores its string value in the
+// 'buffer' of 'buffer_size' bytes on success.
+//
+// - Returns the length of the environment variable string value (excluding
+// the terminating 0) if it is set and its contents (including the terminating
+// 0) can be stored in the 'buffer' without truncation.
+//
+// - Returns negated length of environment variable string value and writes
+// "\0" to the buffer (if it is not NULL) if the 'buffer_size' is to small to
+// store the value (including the terminating 0) without truncation.
+//
+// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment
+// variable is not set.
+//
+// - Returns INT_MIN if the 'name' is NULL.
+//
+// - Returns INT_MIN if the 'buffer_size' is negative.
+//
+// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than
+// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to
+// retrieve the length of the environment variable value string.
+//
+int getenv(const char *name, char *buffer, int buffer_size);
+// Reads an integer from the environment
+int getenv_int(const char *name, int default_value = 0);
+bool jit_dump_enabled();
+FILE *fopen(const char *filename, const char *mode);
+
+constexpr int msan_enabled = MSAN_ENABLED;
+inline void msan_unpoison(void *ptr, size_t size) {
+#if MSAN_ENABLED
+ __msan_unpoison(ptr, size);
+#endif
+}
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp
new file mode 100644
index 0000000000..89a57772cf
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp
@@ -0,0 +1,665 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <stdlib.h>
+#ifndef _WIN32
+#include <sys/time.h>
+#endif
+
+#include "mkldnn.h"
+#include "mkldnn_version.h"
+#include "c_types_map.hpp"
+#include "verbose.hpp"
+#include "cpu/cpu_isa_traits.hpp"
+
+#include "batch_normalization_pd.hpp"
+#include "pooling_pd.hpp"
+#include "concat_pd.hpp"
+#include "reorder_pd.hpp"
+#include "convolution_pd.hpp"
+#include "rnn_pd.hpp"
+#include "deconvolution_pd.hpp"
+#include "shuffle_pd.hpp"
+#include "eltwise_pd.hpp"
+#include "softmax_pd.hpp"
+#include "inner_product_pd.hpp"
+#include "sum_pd.hpp"
+#include "lrn_pd.hpp"
+
+/* MKL-DNN CPU ISA info */
+#define ISA_ANY "No instruction set specific optimizations"
+#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)"
+#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)"
+#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)"
+#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
+ "AVX-512)"
+#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
+ "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions"
+#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \
+ "AVX512-DL Boost)"
+#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
+ "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions"
+#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
+ "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions"
+
+namespace mkldnn {
+namespace impl {
+
+static verbose_t verbose;
+static bool initialized;
+static bool version_printed = false;
+
+const verbose_t *mkldnn_verbose() {
+#if !defined(DISABLE_VERBOSE)
+ if (!initialized) {
+ const int len = 2;
+ char val[len] = {0};
+ if (getenv("MKLDNN_VERBOSE", val, len) == 1)
+ verbose.level = atoi(val);
+ initialized = true;
+ }
+ if (!version_printed && verbose.level > 0) {
+ printf("mkldnn_verbose,info,"
+ "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n",
+ mkldnn_version()->major, mkldnn_version()->minor,
+ mkldnn_version()->patch, mkldnn_version()->hash,
+ get_isa_info());
+ version_printed = true;
+ }
+#else
+ verbose.level = 0;
+#endif
+ return &verbose;
+}
+
+double get_msec() {
+#ifdef _WIN32
+ static LARGE_INTEGER frequency;
+ if (frequency.QuadPart == 0)
+ QueryPerformanceFrequency(&frequency);
+ LARGE_INTEGER now;
+ QueryPerformanceCounter(&now);
+ return 1e+3 * now.QuadPart / frequency.QuadPart;
+#else
+ struct timeval time;
+ gettimeofday(&time, NULL);
+ return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec;
+#endif
+}
+
+const char *get_isa_info() {
+ using namespace mkldnn::impl::cpu;
+ if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS;
+ if (mayiuse(avx512_mic)) return AVX512_MIC;
+ if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI;
+ if (mayiuse(avx512_core)) return AVX512_CORE;
+ if (mayiuse(avx512_common)) return AVX512_COMMON;
+ if (mayiuse(avx2)) return AVX2;
+ if (mayiuse(avx)) return AVX;
+ if (mayiuse(sse42)) return SSE42;
+ return ISA_ANY;
+}
+
+/* init_info section */
+namespace {
+#if !defined(DISABLE_VERBOSE)
+#define MKLDNN_VERBOSE_DAT_LEN 256
+#define MKLDNN_VERBOSE_AUX_LEN 384
+#define MKLDNN_VERBOSE_PRB_LEN 384
+
+#define DECL_DAT_AUX_PRB_STRS() \
+ int dat_written = 0, aux_written = 0, prb_written = 0; \
+ MAYBE_UNUSED((dat_written * aux_written * prb_written)); \
+ char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \
+ char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; MAYBE_UNUSED(aux_str); \
+ char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str)
+
+#define DFMT "%" PRId64
+
+void clear_buf(char *buf, int &written) {
+ /* TODO: do it better */
+ buf[0] = '#';
+ buf[1] = '\0';
+ written = 1;
+}
+
+#define DPRINT(buf, buf_len, written, ...) do { \
+ int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \
+ if (l < 0 || written + l > buf_len) { \
+ clear_buf(buf, written); \
+ } else { \
+ written += l; \
+ } \
+} while(0)
+
+// XXX: Outputs strings corresponding to memory formats used for data tensors.
+void format_prb_desc_str(char *str, int len, const memory_desc_t *md) {
+ const auto dims = md->dims;
+ int written = 0;
+ if (md->ndims == 1)
+ DPRINT(str, len, written,
+ "x" DFMT, dims[0]);
+ else if (md->ndims == 2)
+ DPRINT(str, len, written,
+ "mb" DFMT "ic" DFMT, dims[0], dims[1]);
+ else if (md->ndims == 3)
+ DPRINT(str, len, written,
+ "mb" DFMT "ic" DFMT "iw" DFMT,
+ dims[0], dims[1], dims[2]);
+ else if (md->ndims == 4)
+ DPRINT(str, len, written,
+ "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT,
+ dims[0], dims[1], dims[2], dims[3]);
+ else if (md->ndims == 5)
+ DPRINT(str, len, written,
+ "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT,
+ dims[0], dims[1], dims[2], dims[3], dims[4]);
+ else
+ mkldnn_md2dim_str(str, len, md);
+}
+
+void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind,
+ const char *impl_str, mkldnn_prop_kind_t prop_kind,
+ const char *data_str, const char *aux_str, const char *prb_str) {
+ MAYBE_UNUSED(verbose_templ);
+ int written = 0;
+ DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s",
+ mkldnn_prim_kind2str(prim_kind), impl_str,
+ mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_bnorm(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // data
+ auto md = s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // diff data
+ auto md = s->diff_src_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "flags:%u", s->desc()->flags);
+
+ format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_conv(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src
+ auto md = s->desc()->prop_kind == prop_kind::backward_data
+ ? s->diff_src_md() : s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // wei
+ auto md = s->desc()->prop_kind == prop_kind::backward_weights
+ ? s->diff_weights_md() : s->weights_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // bia
+ auto md = s->desc()->prop_kind == prop_kind::backward_weights
+ ? s->diff_weights_md(1) : s->weights_md(1);
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+ if (1) { // dst
+ auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
+
+ if (s->ndims() == 5) {
+ if (s->with_groups())
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT
+ "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT
+ "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
+ "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
+ s->MB(), s->G(), s->IC(), s->OC(),
+ s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
+ else
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "_ic" DFMT "oc" DFMT
+ "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT
+ "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
+ "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
+ s->MB(), s->IC(), s->OC(),
+ s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
+ } else {
+ if (s->with_groups())
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT
+ "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
+ "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
+ s->MB(), s->G(), s->IC(), s->OC(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
+ else
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "_ic" DFMT "oc" DFMT
+ "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
+ "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
+ s->MB(), s->IC(), s->OC(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
+ }
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_shuffle(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ auto md = s->is_fwd() ? s->src_md() : s->diff_dst_md();
+
+ if (1) { // data
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "axis:%d group_size:" DFMT, s->axis(), s->group_size());
+
+ mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md);
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_eltwise(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // data
+ auto md = s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // diff data
+ auto md = s->diff_src_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
+
+ mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_iprod(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src
+ auto md = s->desc()->prop_kind == prop_kind::backward_data
+ ? s->diff_src_md() : s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // wei
+ auto md = s->desc()->prop_kind == prop_kind::backward_weights
+ ? s->diff_weights_md() : s->weights_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // bia
+ auto md = s->desc()->prop_kind == prop_kind::backward_weights
+ ? s->diff_weights_md(1) : s->weights_md(1);
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+ if (1) { // dst
+ auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_lrn(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // data
+ auto md = s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // diff data
+ auto md = s->diff_src_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
+
+ format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_mem(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src
+ auto md = s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // dst
+ auto md = s->dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "num:%d", s->n_inputs());
+
+ mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_pool(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src
+ auto md = s->is_fwd() ? s->src_md() : s->diff_src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // dst
+ auto md = s->is_fwd() ? s->dst_md() : s->diff_dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // ws
+ auto md = s->workspace_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
+
+ if (s->is_3d()) {
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "ic" DFMT "_"
+ "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_"
+ "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_"
+ "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "",
+ s->MB(), s->C(),
+ s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->padL());
+ } else {
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "ic" DFMT "_"
+ "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_"
+ "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT,
+ s->MB(), s->C(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->padL());
+ }
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_softmax(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // data
+ auto md = s->dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // diff data
+ auto md = s->diff_src_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_rnn(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src layer
+ auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // src iter
+ auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_iter_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // wei_layer
+ auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // wei_iter
+ auto md = s->is_fwd() ? s->weights_md(1) : s->diff_weights_md(1);
+        DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_iter_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // bias
+ auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // dst layer
+ auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_layer_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // dst iter
+ auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_iter_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ alg_kind_t alg_kind = s->cell_kind();
+ rnn_direction_t rnn_dir = s->direction();
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s_%s", mkldnn_alg_kind2str(alg_kind),
+ mkldnn_rnn_direction2str(rnn_dir));
+
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "l" DFMT "t" DFMT "mb" DFMT
+ "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT,
+ s->L(), s->T(), s->MB(),
+ s->SIC(), s->SLC(), s->DIC(), s->DLC());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+#undef DPRINT
+
+#else // !defined(DISABLE_VERBOSE)
+
+#define DEFINE_STUB(name) \
+ template <typename pd_t> \
+ static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \
+ { UNUSED(s); UNUSED(buffer); }
+
+DEFINE_STUB(bnorm);
+DEFINE_STUB(conv);
+DEFINE_STUB(eltwise);
+DEFINE_STUB(iprod);
+DEFINE_STUB(lrn);
+DEFINE_STUB(mem);
+DEFINE_STUB(pool);
+DEFINE_STUB(softmax);
+DEFINE_STUB(rnn);
+DEFINE_STUB(shuffle);
+#undef DEFINE_STUB
+
+#endif // !defined(DISABLE_VERBOSE)
+}
+
+void init_info(batch_normalization_pd_t *s, char *b)
+{ init_info_bnorm(s, b); }
+void init_info(concat_pd_t *s, char *b)
+{ init_info_mem(s, b); }
+void init_info(convolution_pd_t *s, char *b)
+{ init_info_conv(s, b); }
+void init_info(deconvolution_pd_t *s, char *b)
+{ init_info_conv(s, b); }
+void init_info(eltwise_pd_t *s, char *b)
+{ init_info_eltwise(s, b); }
+void init_info(inner_product_pd_t *s, char *b)
+{ init_info_iprod(s, b); }
+void init_info(lrn_pd_t *s, char *b)
+{ init_info_lrn(s, b); }
+void init_info(pooling_pd_t *s, char *b)
+{ init_info_pool(s, b); }
+void init_info(reorder_pd_t *s, char *b)
+{ init_info_mem(s, b); }
+void init_info(rnn_pd_t *s, char *b)
+{ init_info_rnn(s, b); }
+void init_info(shuffle_pd_t *s, char *b)
+{ init_info_shuffle(s, b); }
+void init_info(softmax_pd_t *s, char *b)
+{ init_info_softmax(s, b); }
+void init_info(sum_pd_t *s, char *b)
+{ init_info_mem(s, b); }
+
+}
+}
+
+mkldnn_status_t mkldnn_set_verbose(int level) {
+ using namespace mkldnn::impl::status;
+ if (level < 0 || level > 2) return invalid_arguments;
+ mkldnn::impl::verbose.level = level;
+ mkldnn::impl::initialized = true;
+ return success;
+}
+
+const mkldnn_version_t *mkldnn_version() {
+ static mkldnn_version_t ver = {
+ MKLDNN_VERSION_MAJOR,
+ MKLDNN_VERSION_MINOR,
+ MKLDNN_VERSION_PATCH,
+ MKLDNN_VERSION_HASH};
+ return &ver;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp
new file mode 100644
index 0000000000..e3049750cb
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp
@@ -0,0 +1,62 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef VERBOSE_HPP
+#define VERBOSE_HPP
+
+#include <stdio.h>
+#include <cinttypes>
+
+#include "mkldnn_debug.h"
+#include "c_types_map.hpp"
+#include "utils.hpp"
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct verbose_t {
+ int level;
+};
+
+const verbose_t *mkldnn_verbose();
+double get_msec();
+const char *get_isa_info();
+
+#if !defined(DISABLE_VERBOSE)
+#define MKLDNN_VERBOSE_BUF_LEN 1024
+#else
+#define MKLDNN_VERBOSE_BUF_LEN 1
+#endif
+
+void init_info(batch_normalization_pd_t *s, char *buffer);
+void init_info(concat_pd_t *s, char *buffer);
+void init_info(convolution_pd_t *s, char *buffer);
+void init_info(deconvolution_pd_t *s, char *buffer);
+void init_info(eltwise_pd_t *s, char *buffer);
+void init_info(inner_product_pd_t *s, char *buffer);
+void init_info(lrn_pd_t *s, char *buffer);
+void init_info(pooling_pd_t *s, char *buffer);
+void init_info(reorder_pd_t *s, char *buffer);
+void init_info(rnn_pd_t *s, char *buffer);
+void init_info(shuffle_pd_t *s, char *buffer);
+void init_info(softmax_pd_t *s, char *buffer);
+void init_info(sum_pd_t *s, char *buffer);
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp
new file mode 100644
index 0000000000..520bd4710b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp
@@ -0,0 +1,46 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef Z_MAGIC_HPP
+#define Z_MAGIC_HPP
+
+#define CHAIn2(a,b) a b
+#define CHAIN2(a,b) CHAIn2(a,b)
+
+#define CONCAt2(a,b) a ## b
+#define CONCAT2(a,b) CONCAt2(a,b)
+
+#define STRINGIFy(s) #s
+#define STRINGIFY(s) STRINGIFy(s)
+
+#ifdef _MSC_VER
+# define PRAGMA_MACRo(x) __pragma(x)
+# define PRAGMA_MACRO(x) PRAGMA_MACRo(x)
+#else
+# define PRAGMA_MACRo(x) _Pragma(#x)
+# define PRAGMA_MACRO(x) PRAGMA_MACRo(x)
+#endif
+
+#define UNUSED(x) ((void)x)
+#define MAYBE_UNUSED(x) UNUSED(x)
+
+#if defined(_WIN32) && !defined(__GNUC__)
+#define __PRETTY_FUNCTION__ __FUNCSIG__
+#endif
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s