diff options
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/include/mkldnn.hpp')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/include/mkldnn.hpp | 2615 |
1 files changed, 2615 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp new file mode 100644 index 0000000000..581400a013 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp @@ -0,0 +1,2615 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_HPP +#define MKLDNN_HPP + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include <stdlib.h> +#include <memory> +#include <vector> +#include <unordered_map> +#include <algorithm> +#include <iterator> + +#include "mkldnn.h" +#endif + +namespace mkldnn { + +/// @addtogroup cpp_api C++ API +/// @{ + +/// @addtogroup cpp_api_utils Utils +/// @{ + +/// A class that provides the destructor for an Intel(R) MKL-DNN C handle +template <typename T> class handle_traits {}; + +/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base +/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and +/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class +/// can be passed by value. This class enables wrapping: +/// - Newly constructed handles. +/// @n In this case, the constructed handle uses reference counting provided +/// by @p std::shared_ptr with a proper deleter function specified through +/// the @p handle_traits class. +/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for +/// example, through mkldnn_primitive_get_primitive_desc()). +/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a +/// deleter because it is assumed that the handle wrapper for the original +/// object deletes the handle (this model is similar to @p std::weak_ptr). +template <typename T, typename traits=handle_traits<T>> class handle { +private: + std::shared_ptr<typename std::remove_pointer<T>::type> _data; + handle(const handle &&) = delete; + handle &operator=(const handle &&other) = delete; +protected: + bool operator==(const T other) const { return other == _data.get(); } + bool operator!=(const T other) const { return !(*this == other); } +public: + /// Constructs a C handle wrapper. + /// @param t The C handle to wrap. + /// @param weak A flag to specify whether to construct a weak wrapper. + handle(T t = 0, bool weak = false): _data(0) { + reset(t, weak); + } + + handle(const handle &other): _data(other._data) {} + handle &operator=(const handle &other) { + _data = other._data; + return *this; + } + /// Resets the value of a C handle. + /// @param t The new value of the C handle. + /// @param weak A flag to specify whether the wrapper should be weak. + void reset(T t, bool weak = false) { + auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); }; + _data.reset(t, weak ? dummy_destructor : traits::destructor); + } + + /// Returns the value of the underlying C handle. + T get() const { return _data.get(); } + + bool operator==(const handle &other) const { return other._data.get() == _data.get(); } + bool operator!=(const handle &other) const { return !(*this == other); } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_memory_t> { + static constexpr auto destructor = &mkldnn_memory_destroy; +}; + +template <> struct handle_traits<mkldnn_primitive_desc_t> { + static constexpr auto destructor = &mkldnn_primitive_desc_destroy; +}; + +template <> struct handle_traits<mkldnn_primitive_t> { + static constexpr auto destructor = &mkldnn_primitive_destroy; +}; + +template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> { + static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy; +}; +#endif + +struct memory; +struct primitive_desc; + +/// Base class for all computational primitives. +class primitive: public handle<mkldnn_primitive_t> { + friend struct error; + friend struct stream; + using handle::handle; +public: + /// A proxy to C primitive kind enum + enum class kind { + undefined_primitive = mkldnn_undefined_primitive, + reorder = mkldnn_reorder, + concat = mkldnn_concat, + sum = mkldnn_sum, + convolution = mkldnn_convolution, + deconvolution = mkldnn_deconvolution, + shuffle = mkldnn_shuffle, + eltwise = mkldnn_eltwise, + softmax = mkldnn_softmax, + pooling = mkldnn_pooling, + lrn = mkldnn_lrn, + batch_normalization = mkldnn_batch_normalization, + inner_product = mkldnn_inner_product, + rnn = mkldnn_rnn, + }; + + primitive(const_mkldnn_primitive_desc_t c_pd); + primitive(const primitive_desc &pd); + + /// Returns the descriptor of the underlying C API primitive. + inline const_mkldnn_primitive_desc_t get_primitive_desc() const; + // TODO: use the C++ API wrapper structure. + + void execute(struct stream &astream, + const std::unordered_map<int, memory> &args) const; +}; + +inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { + return static_cast<mkldnn_primitive_kind_t>(akind); +} +/// Intel(R) MKL-DNN exception class. +/// +/// This class captures the status returned by the failed C API function, error +/// message, and, optionally, handle of the primitive that caused the error. +struct error: public std::exception { + mkldnn_status_t status; + const char *message; + + /// Constructs an error instance. + /// + /// @param astatus The error status returned by the C API. + /// @param amessage The error message. + error(mkldnn_status_t astatus, const char *amessage) + : status(astatus), message(amessage) {} + + /// A convenience function for wrapping calls to the C API. Checks the + /// return status and throws an #error in case of failure. + /// + /// @param status The error status returned by the C API. + /// @param message The error message. + static void wrap_c_api(mkldnn_status_t status, const char *message) { + if (status != mkldnn_success) + throw error(status, message); + } +}; + +const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { + const_mkldnn_primitive_desc_t pd; + error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), + "could not get primitive descriptor by primitive"); + return pd; +} +/// @} + +/// @addtogroup cpp_api_enums Common data types and enumerations +/// A proxy to @ref c_api_types in @ref c_api. +/// +/// @{ + +enum scratchpad_mode { + scratchpad_mode_library = mkldnn_scratchpad_mode_library, + scratchpad_mode_user = mkldnn_scratchpad_mode_user, +}; + +inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { + return static_cast<mkldnn_scratchpad_mode_t>(mode); +} + +enum padding_kind { + zero = mkldnn_padding_zero +}; + +inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { + return static_cast<mkldnn_padding_kind_t>(kind); +} + +enum prop_kind { + forward_training = mkldnn_forward_training, + forward_scoring = mkldnn_forward_scoring, + forward_inference = mkldnn_forward_inference, + forward = mkldnn_forward, + backward = mkldnn_backward, + backward_data = mkldnn_backward_data, + backward_weights = mkldnn_backward_weights, + backward_bias = mkldnn_backward_bias +}; + +inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { + return static_cast<mkldnn_prop_kind_t>(kind); +} + +enum algorithm { + algorithm_undef = mkldnn_alg_kind_undef, + convolution_auto = mkldnn_convolution_auto, + convolution_direct = mkldnn_convolution_direct, + convolution_winograd = mkldnn_convolution_winograd, + deconvolution_direct = mkldnn_deconvolution_direct, + deconvolution_winograd = mkldnn_deconvolution_winograd, + eltwise_relu = mkldnn_eltwise_relu, + eltwise_tanh = mkldnn_eltwise_tanh, + eltwise_elu = mkldnn_eltwise_elu, + eltwise_square = mkldnn_eltwise_square, + eltwise_abs = mkldnn_eltwise_abs, + eltwise_sqrt = mkldnn_eltwise_sqrt, + eltwise_linear = mkldnn_eltwise_linear, + eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, + eltwise_soft_relu = mkldnn_eltwise_soft_relu, + eltwise_logistic = mkldnn_eltwise_logistic, + lrn_across_channels = mkldnn_lrn_across_channels, + lrn_within_channel = mkldnn_lrn_within_channel, + pooling_max = mkldnn_pooling_max, + pooling_avg = mkldnn_pooling_avg, + pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, + pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, + vanilla_rnn = mkldnn_vanilla_rnn, + vanilla_lstm = mkldnn_vanilla_lstm, + vanilla_gru = mkldnn_vanilla_gru, + gru_linear_before_reset = mkldnn_gru_linear_before_reset +}; + +inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { + return static_cast<mkldnn_alg_kind_t>(aalgorithm); +} + +enum batch_normalization_flag { + use_global_stats = mkldnn_use_global_stats, + use_scale_shift = mkldnn_use_scaleshift, + fuse_bn_relu = mkldnn_fuse_bn_relu +}; + +inline mkldnn_batch_normalization_flag_t convert_to_c( + batch_normalization_flag aflag) { + return static_cast<mkldnn_batch_normalization_flag_t>(aflag); +} + +enum rnn_direction { + unidirectional_left2right = mkldnn_unidirectional_left2right, + unidirectional_right2left = mkldnn_unidirectional_right2left, + unidirectional = mkldnn_unidirectional, + bidirectional_concat = mkldnn_bidirectional_concat, + bidirectional_sum = mkldnn_bidirectional_sum, +}; + +inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { + return static_cast<mkldnn_rnn_direction_t>(adir); +} + +enum query { + undef = mkldnn_query_undef, + + query_engine = mkldnn_query_engine, + primitive_kind = mkldnn_query_primitive_kind, + + num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32, + num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, + + time_estimate_f64 = mkldnn_query_time_estimate_f64, + memory_consumption_s64 = mkldnn_query_memory_consumption_s64, + + query_scratchpad_engine = mkldnn_query_scratchpad_engine, + + impl_info_str = mkldnn_query_impl_info_str, + + op_d = mkldnn_query_op_d, + convolution_d = mkldnn_query_convolution_d, + deconvolution_d = mkldnn_query_deconvolution_d, + shuffle_d = mkldnn_query_shuffle_d, + eltwise_d = mkldnn_query_eltwise_d, + softmax_d = mkldnn_query_softmax_d, + pooling_d = mkldnn_query_pooling_d, + lrn_d = mkldnn_query_lrn_d, + batch_normalization_d = mkldnn_query_batch_normalization_d, + inner_product_d = mkldnn_query_inner_product_d, + rnn_d = mkldnn_query_rnn_d, + + src_md = mkldnn_query_src_md, + diff_src_md = mkldnn_query_diff_src_md, + weights_md = mkldnn_query_weights_md, + diff_weights_md = mkldnn_query_diff_weights_md, + dst_md = mkldnn_query_dst_md, + diff_dst_md = mkldnn_query_diff_dst_md, + workspace_md = mkldnn_query_workspace_md, + scratchpad_md = mkldnn_query_scratchpad_md, +}; + +inline mkldnn_query_t convert_to_c(query aquery) { + return static_cast<mkldnn_query_t>(aquery); +} + +/// @} + +/// @addtogroup cpp_api_attr Attributes +/// An extension for controlling primitive behavior. +/// +/// @sa @ref c_api_attributes in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_post_ops_t> { + static constexpr auto destructor = &mkldnn_post_ops_destroy; +}; +#endif + +struct post_ops: public handle<mkldnn_post_ops_t> { + post_ops() { + mkldnn_post_ops_t result; + error::wrap_c_api(mkldnn_post_ops_create(&result), + "could not create post operation sequence"); + reset(result); + } + + int len() const { return mkldnn_post_ops_len(get()); } + + primitive::kind kind(int index) const { + error::wrap_c_api( + index < len() ? mkldnn_success : mkldnn_invalid_arguments, + "post_ops index is out of range"); + return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(), + index)); + } + + void append_sum(float scale = 1.) { + error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), + "could not append sum"); + } + + void get_params_sum(int index, float &scale) const { + error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), + "could not get sum params"); + } + + void append_eltwise(float scale, algorithm alg, float alpha, + float beta) { + error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale, + convert_to_c(alg), alpha, beta), + "could not append eltwise"); + } + + void get_params_eltwise(int index, float &scale, algorithm &alg, + float &alpha, float &beta) const { + mkldnn_alg_kind_t c_alg; + error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index, + &scale, &c_alg, &alpha, &beta), + "could not get eltwise params"); + alg = static_cast<algorithm>(c_alg); + } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_primitive_attr_t> { + static constexpr auto destructor = &mkldnn_primitive_attr_destroy; +}; +#endif + +struct primitive_attr: public handle<mkldnn_primitive_attr_t> { + primitive_attr() { + mkldnn_primitive_attr_t result; + error::wrap_c_api(mkldnn_primitive_attr_create(&result), + "could not create a primitive attr"); + reset(result); + } + + scratchpad_mode get_scratchpad_mode() const { + mkldnn_scratchpad_mode_t result; + error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode( + get(), &result), "could not get scratchpad mode"); + return scratchpad_mode(result); + } + + void set_scratchpad_mode(scratchpad_mode mode) { + error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode( + get(), mkldnn::convert_to_c(mode)), + "could not set scratchpad mode"); + } + + void get_output_scales(int &mask, std::vector<float> &scales) const + { + mkldnn_dim_t count; + int c_mask; + const float *c_scales; + error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(), + &count, &c_mask, &c_scales), + "could not get int output scales"); + scales.resize(count); + + mask = c_mask; + for (mkldnn_dim_t c = 0; c < count; ++c) + scales[c] = c_scales[c]; + } + + void set_output_scales(int mask, const std::vector<float> &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(), + (mkldnn_dim_t)scales.size(), mask, &scales[0]), + "could not set int output scales"); + } + + const post_ops get_post_ops() const { + post_ops result; + const_mkldnn_post_ops_t c_result; + error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), + "could not get post operation sequence"); + result.reset(const_cast<mkldnn_post_ops_t>(c_result), true); + return result; + } + + void set_post_ops(post_ops ops) { + error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), + "could not set post operation sequence"); + } + + void set_rnn_data_qparams(const float scale, const float shift) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(), + scale, shift), "could not set rnn data int scale/shift"); + } + + void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(), + (int)scales.size(), mask, &scales[0]), + "could not set rnn weights int scales"); + } +}; + +/// @} + +/// @addtogroup cpp_api_engine Engine +/// Engine operations. +/// +/// @sa @ref c_api_engine in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_engine_t> { + static constexpr auto destructor = &mkldnn_engine_destroy; +}; +#endif + +/// An execution engine. +struct engine: public handle<mkldnn_engine_t> { + friend class primitive; + // gcc bug??? using handle::handle; + + /// Kinds of engines. + enum kind { + /// An unspecified engine + any = mkldnn_any_engine, + /// CPU engine + cpu = mkldnn_cpu, + }; + + /// Returns the number of engines of a certain kind. + /// + /// @param akind The kind of engines to count. + + static size_t get_count(kind akind) { + return mkldnn_engine_get_count(convert_to_c(akind)); + } + + /// Constructs an engine. + /// + /// @param akind The kind of engine to construct. + /// @param index The index of the engine. Must be less than the value + /// returned by #get_count() for this particular kind of engine. + + engine(kind akind, size_t index) { + mkldnn_engine_t aengine; + error::wrap_c_api( + mkldnn_engine_create(&aengine, + convert_to_c(akind), index), + "could not create an engine"); + reset(aengine); + } + + explicit engine(const mkldnn_engine_t& aengine) + : handle(aengine, true) {} + + engine(const handle<mkldnn_primitive_desc_t> &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + reset(engine_q, true); + } + + template <class primitive_desc> + static engine query(const primitive_desc &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + + return engine(engine_q); + } + +private: + static mkldnn_engine_kind_t convert_to_c(kind akind) { + return static_cast<mkldnn_engine_kind_t>(akind); + } +}; + +/// @} + +/// @addtogroup cpp_api_stream Stream +/// Execution stream operations +/// +/// @sa @ref c_api_stream in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_stream_t> { + static constexpr auto destructor = &mkldnn_stream_destroy; +}; +#endif + +struct stream: public handle<mkldnn_stream_t> { + using handle::handle; + + enum: unsigned { + default_flags = mkldnn_stream_default_flags, + }; + + /// Constructs a stream. + stream(const engine &aengine, + unsigned flags = static_cast<unsigned>(default_flags)) { + mkldnn_stream_t astream; + error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags), + "could not create a stream"); + reset(astream); + } +}; + +/// @} + +/// @addtogroup cpp_api_memory_related Memory and memory related operations +/// @{ + +/// @addtogroup cpp_api_memory Memory +/// A primitive to describe and store data. +/// +/// For more information, refer to @ref c_api_memory in @ref c_api. +/// @{ + +/// Memory that describes the data. +struct memory: public handle<mkldnn_memory_t> { + public: + typedef mkldnn_dim_t dim; + typedef std::vector<dim> dims; + + template <typename T> static void validate_dims(const std::vector<T> &v) { + if (v.size() > MKLDNN_MAX_NDIMS) + throw error(mkldnn_invalid_arguments, "invalid dimensions"); + } + + /// Data type specification. See #mkldnn_data_type_t for a detailed + /// description. + enum data_type { + data_undef = mkldnn_data_type_undef, + f32 = mkldnn_f32, + s32 = mkldnn_s32, + s8 = mkldnn_s8, + u8 = mkldnn_u8, + }; + + /// Memory format tag specification. See #mkldnn_format_tag_t + /// for a detailed description. + enum format_tag { + format_tag_undef = mkldnn_format_tag_undef, + any = mkldnn_format_tag_any, + a = mkldnn_a, + ab = mkldnn_ab, + abc = mkldnn_abc, + abcd = mkldnn_abcd, + abcde = mkldnn_abcde, + abcdef = mkldnn_abcdef, + abdec = mkldnn_abdec, + acb = mkldnn_acb, + acbde = mkldnn_acbde, + acdb = mkldnn_acdb, + acdeb = mkldnn_acdeb, + ba = mkldnn_ba, + bac = mkldnn_bac, + bacd = mkldnn_bacd, + bcda = mkldnn_bcda, + cba = mkldnn_cba, + cdba = mkldnn_cdba, + cdeba = mkldnn_cdeba, + decab = mkldnn_decab, + Abc16a = mkldnn_Abc16a, + ABc16a16b = mkldnn_ABc16a16b, + aBc16b = mkldnn_aBc16b, + ABc16b16a = mkldnn_ABc16b16a, + Abc4a = mkldnn_Abc4a, + aBc4b = mkldnn_aBc4b, + ABc4b16a4b = mkldnn_ABc4b16a4b, + ABc4b4a = mkldnn_ABc4b4a, + ABc8a16b2a = mkldnn_ABc8a16b2a, + ABc8a8b = mkldnn_ABc8a8b, + aBc8b = mkldnn_aBc8b, + ABc8b16a2b = mkldnn_ABc8b16a2b, + ABc8b8a = mkldnn_ABc8b8a, + Abcd16a = mkldnn_Abcd16a, + ABcd16a16b = mkldnn_ABcd16a16b, + aBcd16b = mkldnn_aBcd16b, + ABcd16b16a = mkldnn_ABcd16b16a, + aBCd16b16c = mkldnn_aBCd16b16c, + aBCd16c16b = mkldnn_aBCd16c16b, + Abcd4a = mkldnn_Abcd4a, + aBcd4b = mkldnn_aBcd4b, + ABcd4b16a4b = mkldnn_ABcd4b16a4b, + ABcd4b4a = mkldnn_ABcd4b4a, + aBCd4c16b4c = mkldnn_aBCd4c16b4c, + aBCd4c4b = mkldnn_aBCd4c4b, + ABcd8a16b2a = mkldnn_ABcd8a16b2a, + ABcd8a8b = mkldnn_ABcd8a8b, + aBcd8b = mkldnn_aBcd8b, + ABcd8b16a2b = mkldnn_ABcd8b16a2b, + aBCd8b16c2b = mkldnn_aBCd8b16c2b, + ABcd8b8a = mkldnn_ABcd8b8a, + aBCd8b8c = mkldnn_aBCd8b8c, + aBCd8c16b2c = mkldnn_aBCd8c16b2c, + aBCd8c8b = mkldnn_aBCd8c8b, + Abcde16a = mkldnn_Abcde16a, + ABcde16a16b = mkldnn_ABcde16a16b, + aBcde16b = mkldnn_aBcde16b, + ABcde16b16a = mkldnn_ABcde16b16a, + aBCde16b16c = mkldnn_aBCde16b16c, + aBCde16c16b = mkldnn_aBCde16c16b, + aBCde2c8b4c = mkldnn_aBCde2c8b4c, + Abcde4a = mkldnn_Abcde4a, + aBcde4b = mkldnn_aBcde4b, + ABcde4b4a = mkldnn_ABcde4b4a, + aBCde4b4c = mkldnn_aBCde4b4c, + aBCde4c16b4c = mkldnn_aBCde4c16b4c, + aBCde4c4b = mkldnn_aBCde4c4b, + Abcde8a = mkldnn_Abcde8a, + ABcde8a8b = mkldnn_ABcde8a8b, + aBcde8b = mkldnn_aBcde8b, + ABcde8b16a2b = mkldnn_ABcde8b16a2b, + aBCde8b16c2b = mkldnn_aBCde8b16c2b, + ABcde8b8a = mkldnn_ABcde8b8a, + aBCde8b8c = mkldnn_aBCde8b8c, + aBCde8c16b2c = mkldnn_aBCde8c16b2c, + aBCde8c8b = mkldnn_aBCde8c8b, + aBcdef16b = mkldnn_aBcdef16b, + aBCdef16b16c = mkldnn_aBCdef16b16c, + aBCdef16c16b = mkldnn_aBCdef16c16b, + aBcdef4b = mkldnn_aBcdef4b, + aBCdef4c4b = mkldnn_aBCdef4c4b, + aBCdef8b8c = mkldnn_aBCdef8b8c, + aBCdef8c16b2c = mkldnn_aBCdef8c16b2c, + aBCdef8c8b = mkldnn_aBCdef8c8b, + aBdc16b = mkldnn_aBdc16b, + aBdc4b = mkldnn_aBdc4b, + aBdc8b = mkldnn_aBdc8b, + aBdec16b = mkldnn_aBdec16b, + aBdec4b = mkldnn_aBdec4b, + aBdec8b = mkldnn_aBdec8b, + aBdefc16b = mkldnn_aBdefc16b, + aBdefc4b = mkldnn_aBdefc4b, + aBdefc8b = mkldnn_aBdefc8b, + Acb16a = mkldnn_Acb16a, + Acb4a = mkldnn_Acb4a, + Acb8a = mkldnn_Acb8a, + aCBd16b16c = mkldnn_aCBd16b16c, + aCBde16b16c = mkldnn_aCBde16b16c, + Acdb16a = mkldnn_Acdb16a, + Acdb4a = mkldnn_Acdb4a, + Acdb8a = mkldnn_Acdb8a, + Acdeb16a = mkldnn_Acdeb16a, + Acdeb4a = mkldnn_Acdeb4a, + Acdeb8a = mkldnn_Acdeb8a, + BAc16a16b = mkldnn_BAc16a16b, + BAcd16a16b = mkldnn_BAcd16a16b, + format_tag_last = mkldnn_format_tag_last, + + x = mkldnn_x, + nc = mkldnn_nc, + cn = mkldnn_cn, + ncw = mkldnn_ncw, + nwc = mkldnn_nwc, + nchw = mkldnn_nchw, + nhwc = mkldnn_nhwc, + chwn = mkldnn_chwn, + ncdhw = mkldnn_ncdhw, + ndhwc = mkldnn_ndhwc, + oi = mkldnn_oi, + io = mkldnn_io, + oiw = mkldnn_oiw, + wio = mkldnn_wio, + oihw = mkldnn_oihw, + hwio = mkldnn_hwio, + ihwo = mkldnn_ihwo, + iohw = mkldnn_iohw, + oidhw = mkldnn_oidhw, + dhwio = mkldnn_dhwio, + goiw = mkldnn_goiw, + goihw = mkldnn_goihw, + hwigo = mkldnn_hwigo, + giohw = mkldnn_giohw, + goidhw = mkldnn_goidhw, + tnc = mkldnn_tnc, + ntc = mkldnn_ntc, + ldsnc = mkldnn_ldsnc, + ldigo = mkldnn_ldigo, + ldgoi = mkldnn_ldgoi, + ldgo = mkldnn_ldgo, + nCdhw16c = mkldnn_nCdhw16c, + nCdhw4c = mkldnn_nCdhw4c, + nCdhw8c = mkldnn_nCdhw8c, + nChw16c = mkldnn_nChw16c, + nChw4c = mkldnn_nChw4c, + nChw8c = mkldnn_nChw8c, + nCw16c = mkldnn_nCw16c, + nCw4c = mkldnn_nCw4c, + nCw8c = mkldnn_nCw8c, + IOw16o16i = mkldnn_IOw16o16i, + OIw16i16o = mkldnn_OIw16i16o, + OIw16o16i = mkldnn_OIw16o16i, + Oiw16o = mkldnn_Oiw16o, + OIw4i16o4i = mkldnn_OIw4i16o4i, + OIw4i4o = mkldnn_OIw4i4o, + Oiw4o = mkldnn_Oiw4o, + OIw8i16o2i = mkldnn_OIw8i16o2i, + OIw8i8o = mkldnn_OIw8i8o, + OIw8o16i2o = mkldnn_OIw8o16i2o, + OIw8o8i = mkldnn_OIw8o8i, + Owi16o = mkldnn_Owi16o, + Owi4o = mkldnn_Owi4o, + Owi8o = mkldnn_Owi8o, + IOhw16o16i = mkldnn_IOhw16o16i, + Ohwi16o = mkldnn_Ohwi16o, + Ohwi4o = mkldnn_Ohwi4o, + Ohwi8o = mkldnn_Ohwi8o, + OIhw16i16o = mkldnn_OIhw16i16o, + OIhw16o16i = mkldnn_OIhw16o16i, + Oihw16o = mkldnn_Oihw16o, + OIhw4i16o4i = mkldnn_OIhw4i16o4i, + OIhw4i4o = mkldnn_OIhw4i4o, + Oihw4o = mkldnn_Oihw4o, + OIhw8i16o2i = mkldnn_OIhw8i16o2i, + OIhw8i8o = mkldnn_OIhw8i8o, + OIhw8o16i2o = mkldnn_OIhw8o16i2o, + OIhw8o8i = mkldnn_OIhw8o8i, + Odhwi16o = mkldnn_Odhwi16o, + Odhwi4o = mkldnn_Odhwi4o, + Odhwi8o = mkldnn_Odhwi8o, + OIdhw16i16o = mkldnn_OIdhw16i16o, + OIdhw16o16i = mkldnn_OIdhw16o16i, + Oidhw16o = mkldnn_Oidhw16o, + OIdhw4i4o = mkldnn_OIdhw4i4o, + Oidhw4o = mkldnn_Oidhw4o, + OIdhw8i16o2i = mkldnn_OIdhw8i16o2i, + OIdhw8i8o = mkldnn_OIdhw8i8o, + OIdhw8o8i = mkldnn_OIdhw8o8i, + gIOw16o16i = mkldnn_gIOw16o16i, + gOIw16i16o = mkldnn_gOIw16i16o, + gOIw16o16i = mkldnn_gOIw16o16i, + gOiw16o = mkldnn_gOiw16o, + gOIw4i16o4i = mkldnn_gOIw4i16o4i, + gOIw4i4o = mkldnn_gOIw4i4o, + gOiw4o = mkldnn_gOiw4o, + gOIw8i16o2i = mkldnn_gOIw8i16o2i, + gOIw8i8o = mkldnn_gOIw8i8o, + gOIw8o16i2o = mkldnn_gOIw8o16i2o, + gOIw8o8i = mkldnn_gOIw8o8i, + gOwi16o = mkldnn_gOwi16o, + gOwi4o = mkldnn_gOwi4o, + gOwi8o = mkldnn_gOwi8o, + gIOhw16o16i = mkldnn_gIOhw16o16i, + gOhwi16o = mkldnn_gOhwi16o, + gOhwi4o = mkldnn_gOhwi4o, + gOhwi8o = mkldnn_gOhwi8o, + Goihw16g = mkldnn_Goihw16g, + gOIhw16i16o = mkldnn_gOIhw16i16o, + gOIhw16o16i = mkldnn_gOIhw16o16i, + gOihw16o = mkldnn_gOihw16o, + gOIhw2i8o4i = mkldnn_gOIhw2i8o4i, + gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, + gOIhw4i4o = mkldnn_gOIhw4i4o, + gOIhw4o4i = mkldnn_gOIhw4o4i, + gOihw4o = mkldnn_gOihw4o, + Goihw8g = mkldnn_Goihw8g, + gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, + gOIhw8i8o = mkldnn_gOIhw8i8o, + gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, + gOIhw8o8i = mkldnn_gOIhw8o8i, + gOdhwi16o = mkldnn_gOdhwi16o, + gOdhwi4o = mkldnn_gOdhwi4o, + gOdhwi8o = mkldnn_gOdhwi8o, + gOIdhw16i16o = mkldnn_gOIdhw16i16o, + gOIdhw16o16i = mkldnn_gOIdhw16o16i, + gOidhw16o = mkldnn_gOidhw16o, + gOIdhw4i4o = mkldnn_gOIdhw4i4o, + gOidhw4o = mkldnn_gOidhw4o, + gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i, + gOIdhw8i8o = mkldnn_gOIdhw8i8o, + gOIdhw8o8i = mkldnn_gOIdhw8o8i, + }; + + /// A memory descriptor. + struct desc { + friend struct memory; + /// The underlying C API data structure. + mkldnn_memory_desc_t data; + + /// Constructs a zero memory descriptor + desc(): data() {} + + /// Constructs a memory descriptor. + /// + /// @param adims Data dimensions + /// @param adata_type Data precision/type. + /// @param aformat Data layout format tag. + desc(const dims &adims, data_type adata_type, + format_tag aformat) { + validate_dims(adims); + error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(), + adims.size() == 0 ? nullptr : &adims[0], + convert_to_c(adata_type), convert_to_c(aformat)), + "could not initialize a memory descriptor"); + } + + /// Constructs a memory descriptor from a C API data structure. + /// + /// @param adata A C API #mkldnn_memory_desc_t structure. + desc(const mkldnn_memory_desc_t &adata): data(adata) {} + + /// Constructs a sub-memory descriptor + // + /// @param adims Sizes of a sub-memory + /// @param offsets Offsets of a sub-memory + desc submemory_desc(const dims &adims, const dims &offsets) { + mkldnn_memory_desc_t sub_md; + error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md, + &data, &adims[0], &offsets[0]), + "could not initialize a sub-memory"); + return desc(sub_md); + } + + /// Returns the number of bytes required to allocate the memory described + /// including the padding area. + size_t get_size() const { return mkldnn_memory_desc_get_size(&data); } + + bool operator==(const desc &other) const { + return mkldnn_memory_desc_equal(&data, &other.data) != 0; + } + + bool operator!=(const desc &other) const { return !operator==(other); } + }; + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + /// @param ahandle Native handle. + memory(const desc &md, const engine &aengine, void *ahandle) { + mkldnn_memory_t result; + error::wrap_c_api(mkldnn_memory_create(&result, &md.data, + aengine.get(), ahandle), "could not create a memory"); + reset(result); + } + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + memory(const desc &md, const engine &aengine) + : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {} + + /// Returns the descriptor of the memory. + desc get_desc() const { + const mkldnn_memory_desc_t *cdesc; + error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc), + "could not get memory descriptor from a memory"); + return desc(*cdesc); + } + + /// Returns the engine of the memory. + engine get_engine() const { + mkldnn_engine_t engine_q; + error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q), + "could not get engine from a memory"); + return engine(engine_q); + } + + /// Returns a handle of the data contained in the memory. + /// + /// On the CPU engine, this is a pointer to the allocated memory. + void *get_data_handle() const { + void *handle; + error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle), + "could not get native handle"); + return handle; + } + + void set_data_handle(void *handle) const { + error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle), + "could not set native handle"); + } + + // Must go away or be private: + static mkldnn_data_type_t convert_to_c(data_type adata_type) { + return static_cast<mkldnn_data_type_t>(adata_type); + } + static mkldnn_format_tag_t convert_to_c(format_tag aformat) { + return static_cast<mkldnn_format_tag_t>(aformat); + } +}; + +inline bool operator==(mkldnn_data_type_t a, memory::data_type b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) { + return !(a == b); +} +inline bool operator==(memory::data_type a, mkldnn_data_type_t b) { + return b == a; +} +inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) { + return !(a == b); +} + +inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) { + return !(a == b); +} +inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) { + return b == a; +} +inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) { + return !(a == b); +} + +/// @} + +/// @addtogroup cpp_api_reorder Reorder +/// A primitive to copy data between memory formats. +/// +/// @sa @ref c_api_reorder in @ref c_api +/// @{ + +struct reorder : public primitive { + struct primitive_desc : public handle<mkldnn_primitive_desc_t> { + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine scratchpad_engine() { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q), + "could not get scratchpad engine from reorder primitive_desc"); + + return engine(engine_q); + } + + engine get_engine() { return engine::query(*this); } + }; + + reorder(const primitive_desc &pd): primitive(pd.get()) {} + + reorder(const memory &src, const memory &dst): + primitive(primitive_desc(src, dst).get()) {} + + void execute(stream astream, memory &src, memory &dst) { + primitive::execute(astream, + {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}}); + } +}; + +/// @} + +/// @addtogroup cpp_api_concat Concat +/// A primitive to concatenate data by arbitrary dimension. +/// +/// @sa @ref c_api_concat in @ref c_api +/// @{ + +struct concat : public primitive { + struct primitive_desc : public handle<mkldnn_primitive_desc_t> { + std::vector<mkldnn_memory_desc_t> cpp_to_c( + const std::vector<memory::desc> &srcs) { + std::vector<mkldnn_memory_desc_t> c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, int concat_dimension, + const std::vector<memory::desc> &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + primitive_desc(int concat_dimension, + const std::vector<memory::desc> &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, nullptr, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + concat(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @addtogroup cpp_api_sum Sum +/// A primitive to sum data. +/// +/// @sa @ref c_api_sum in @ref c_api +/// @{ + +struct sum : public primitive { + struct primitive_desc : public handle<mkldnn_primitive_desc_t> { + std::vector<mkldnn_memory_desc_t> cpp_to_c( + const std::vector<memory::desc> &srcs) { + std::vector<mkldnn_memory_desc_t> c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, + const std::vector<float> &scales, + const std::vector<memory::desc> &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + &scales[0], &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + primitive_desc(const std::vector<float> &scales, + const std::vector<memory::desc> &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result, + nullptr, (int)c_api_srcs.size(), &scales[0], + &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + sum(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @} + +/// @addtogroup cpp_api_primitives Primitives +/// @{ + +/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors +/// @{ + +/// A base class for all primitive descriptors. +struct primitive_desc : public handle<mkldnn_primitive_desc_t> { + primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, + const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) { + mkldnn_primitive_desc_iterator_t iterator = nullptr; + mkldnn_status_t status = mkldnn_primitive_desc_iterator_create( + &iterator, desc, attr ? attr->get() : nullptr, e.get(), + hint_fwd_pd); + error::wrap_c_api(status, + "could not create a primitive descriptor iterator"); + pd_iterator.reset(iterator); + fetch_impl(); + } + + engine get_engine() { return engine::query(*this); } + + primitive_attr get_primitive_attr() const { + const_mkldnn_primitive_attr_t const_cattr; + error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr), + "could not get attributes"); + mkldnn_primitive_attr_t cattr; + error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr), + "could not clone attributes"); + + primitive_attr attr; + attr.reset(cattr); + return attr; + } + + /// Returns implementation name + const char *impl_info_str() const { + const char *res; + error::wrap_c_api(mkldnn_primitive_desc_query(get(), + mkldnn_query_impl_info_str, 0, &res), + "could not query implementation info string"); + return res; + } + + /// Queries the memory::dim value (same as int64_t) + memory::dim query_s64(query q) const { + memory::dim res; + mkldnn_status_t status = mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(q), 0, &res); + return status == mkldnn_success ? res : 0; + } + + /// Advances the next implementation for the given op descriptor. + /// + /// Returns: + /// - @c true on success + /// - @c false if the last implementation reached, and + /// the primitive descriptor itself is kept unchanged + bool next_impl() { + mkldnn_status_t status = mkldnn_primitive_desc_iterator_next( + pd_iterator.get()); + if (status == mkldnn_iterator_ends) return false; + error::wrap_c_api(status, "primitive descriptor iterator next failed"); + + fetch_impl(); + return true; + } + + /// Queries and returns requested memory descriptor. + memory::desc query_md(query what, int idx = 0) const { + std::vector<query> valid_q{src_md, diff_src_md, weights_md, + diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), + [=](query q) { return what == q; })) + throw error(mkldnn_invalid_arguments, "invalid memory query"); + + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(what), idx); + if (cdesc == nullptr) return memory::desc(); + + return memory::desc(*cdesc); + } + + // register specialized queries, e.g. src_desc() +# define REG_QUERY_MD(name, what, idx) \ + memory::desc name ## _desc() const { return query_md(what ## _md, idx); } + + private: + handle<mkldnn_primitive_desc_iterator_t> pd_iterator; + void fetch_impl() { + mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch( + pd_iterator.get()); + error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error, + "could not fetch a primitive descriptor from the iterator"); + reset(pd); + } +}; + +/// @} + +/// @addtogroup cpp_api_convolution Convolution +/// A primitive to compute convolution using different algorithms. +/// +/// @sa @ref c_api_convolution in @ref c_api +/// @{ + +struct convolution_forward: public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_data : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_weights : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} +// +/// @addtogroup cpp_api_deconvolution Deconvolution +/// A primitive to compute deconvolution using different algorithms. +/// +/// @sa @ref c_api_deconvolution in @ref c_api +/// @{ + +struct deconvolution_forward: public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_data : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_weights : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_lrn LRN +/// A primitive to perform local response normalization (LRN) across or within +/// channels. +/// +/// @sa @ref c_api_lrn in @ref c_api +/// @{ + +struct lrn_forward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, local_size, alpha, beta, k), + "could not create a lrn forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct lrn_backward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(algorithm aalgorithm, const memory::desc &data_desc, + const memory::desc &diff_data_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, + convert_to_c(aalgorithm), &diff_data_desc.data, + &data_desc.data, local_size, alpha, beta, k), + "could not create a lrn backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_pooling Pooling +/// A primitive to perform max or average pooling. +/// +/// @sa @ref c_api_pooling in @ref c_api +/// @{ + +struct pooling_forward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims kernel, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, &dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a forward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct pooling_backward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::dims &strides, + const memory::dims &kernel, + const memory::dims &padding_l, + const memory::dims &padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data, + convert_to_c(aalgorithm), + &diff_src_desc.data, &diff_dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a backward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_eltwise Eltwise +/// A primitive to compute element-wise operations like parametric rectifier +/// linear unit (ReLU). +/// +/// @sa @ref c_api_eltwise in @ref c_api +/// @{ + +struct eltwise_forward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, algorithm alg_kind, + const memory::desc &src_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + mkldnn::convert_to_c(alg_kind), &src_desc.data, + static_cast<float>(alpha), static_cast<float>(beta)), + "could not create a eltwise forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct eltwise_backward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + + template <typename T> + desc(algorithm alg_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data, + mkldnn::convert_to_c(alg_kind), &diff_data_desc.data, + &data_desc.data, static_cast<float>(alpha), + static_cast<float>(beta)), + "could not create a eltwise backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_softmax Softmax +/// A primitive to perform softmax. +/// +/// @sa @ref c_api_softmax in @ref c_api +/// @{ + +struct softmax_forward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + softmax_axis), + "could not create a softmax forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct softmax_backward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(const memory::desc &diff_desc, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data, + &diff_desc.data, &data_desc.data, softmax_axis), + "could not init a backward softmax descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_batch_norm Batch normalization +/// A primitive to perform batch normalization. +/// +/// @sa @ref c_api_batch_normalization in @ref c_api +/// @{ + +struct batch_normalization_forward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, + unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + static_cast<float>(epsilon), flags), + "could not create a batch normalization forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + + memory::desc mean_desc() const { return stat_desc(mean); } + memory::desc variance_desc() const { return stat_desc(var); } + + private: + enum { mean = 1, var = 2, }; + memory::desc stat_desc(int kind) const { + mkldnn_batch_normalization_desc_t *p; + error::wrap_c_api(mkldnn_primitive_desc_query( + get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), + "could not get a batch-normalization descriptor"); + return query_md(p->flags & use_global_stats ? src_md : dst_md, kind); + } + }; + + batch_normalization_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct batch_normalization_backward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T epsilon, unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + &diff_data_desc.data, &data_desc.data, + static_cast<float>(epsilon), flags), + "could not create a batch normalization backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(mean, src, 1); + REG_QUERY_MD(variance, src, 2); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + batch_normalization_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_inner_product Inner Product +/// A primitive to compute an inner product. +/// +/// @sa @ref c_api_inner_product in @ref c_api +/// @{ + +struct inner_product_forward: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, &bias_desc.data, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, nullptr, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_data: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_data_desc_init(&data, + &diff_src_desc.data, &weights_desc.data, + &diff_dst_desc.data), + "could not create a inner product backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_weights: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + &diff_bias_desc.data, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + nullptr, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_rnn RNN +/// A primitive to compute common recurrent layer. +/// +/// @sa @ref c_api_rnn in @ref c_api +/// @{ + +struct rnn_cell { + struct desc { + mkldnn_rnn_cell_desc_t c_rnn_cell_; + + desc(algorithm kind, algorithm activation_f) { + error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_, + mkldnn::convert_to_c(kind), + mkldnn::convert_to_c(activation_f), 0U, 0, 0), + "could not init an rnn cell descriptor"); + } + desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {} + + operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; } + + algorithm get_cell_kind() const + { return algorithm(c_rnn_cell_.cell_kind); } + algorithm get_activation() const + { return algorithm(c_rnn_cell_.activation_kind); } + + float get_alpha() const { return c_rnn_cell_.alpha; } + void set_alpha(float alpha) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; + c_rnn_cell_.alpha = alpha; + } + + float get_clipping() const { return c_rnn_cell_.clipping; } + void set_clipping(float clipping) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; + c_rnn_cell_.clipping = clipping; + } + + int get_gates_count() const { + return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); + } + int get_state_count() const { + return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); + } + }; +}; + +struct rnn_forward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc + ) { + error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data), + "could not create an RNN forward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + rnn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct rnn_backward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc) { + error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data, + &diff_src_layer_desc.data, &diff_src_iter_desc.data, + &diff_weights_layer_desc.data, + &diff_weights_iter_desc.data, &diff_bias_desc.data, + &diff_dst_layer_desc.data, &diff_dst_iter_desc.data), + "could not create an RNN backward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src_layer, diff_src, 0); + REG_QUERY_MD(diff_src_iter, diff_src, 1); + REG_QUERY_MD(diff_weights_layer, diff_weights, 0); + REG_QUERY_MD(diff_weights_iter, diff_weights, 1); + REG_QUERY_MD(diff_bias, diff_weights, 2); + REG_QUERY_MD(diff_dst_layer, diff_dst, 0); + REG_QUERY_MD(diff_dst_iter, diff_dst, 1); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + // With last iteration (with and without input src_iter) + rnn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_shuffle Shuffle +/// A primitive to shuffle data along the axis. +/// +/// @sa @ref c_api_shuffle in @ref c_api +/// @{ + +struct shuffle_forward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + axis, group_size), + "could not create a shuffle forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct shuffle_backward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(const memory::desc &diff_data_desc, int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data, + &diff_data_desc.data, axis, group_size), + "could not create a shuffle backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const shuffle_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @} Primitives + +/// @} C++ API + +#undef REG_QUERY_MD + +// implementation section +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) { + mkldnn_primitive_t result; + error::wrap_c_api(mkldnn_primitive_create(&result, c_pd), + "could not create a primitive"); + reset(result); +} + +inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {} + +inline void primitive::execute(stream &astream, + const std::unordered_map<int, memory> &args) const { + std::vector<mkldnn_exec_arg_t> c_args; + c_args.reserve(args.size()); + for (const auto &a: args) + c_args.push_back({a.first, a.second.get()}); + + error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(), + (int)c_args.size(), c_args.data()), + "primitive execution fail"); +} +#endif // DOXYGEN_SHOULD_SKIP_THIS + +} // namespace mkldnn + +#endif |