/******************************************************************************* * Copyright 2016-2018 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ #ifndef MKLDNN_HPP #define MKLDNN_HPP #ifndef DOXYGEN_SHOULD_SKIP_THIS #include <stdlib.h> #include <memory> #include <vector> #include <unordered_map> #include <algorithm> #include <iterator> #include "mkldnn.h" #endif namespace mkldnn { /// @addtogroup cpp_api C++ API /// @{ /// @addtogroup cpp_api_utils Utils /// @{ /// A class that provides the destructor for an Intel(R) MKL-DNN C handle template <typename T> class handle_traits {}; /// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base /// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and /// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class /// can be passed by value. This class enables wrapping: /// - Newly constructed handles. /// @n In this case, the constructed handle uses reference counting provided /// by @p std::shared_ptr with a proper deleter function specified through /// the @p handle_traits class. /// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for /// example, through mkldnn_primitive_get_primitive_desc()). /// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a /// deleter because it is assumed that the handle wrapper for the original /// object deletes the handle (this model is similar to @p std::weak_ptr). template <typename T, typename traits=handle_traits<T>> class handle { private: std::shared_ptr<typename std::remove_pointer<T>::type> _data; handle(const handle &&) = delete; handle &operator=(const handle &&other) = delete; protected: bool operator==(const T other) const { return other == _data.get(); } bool operator!=(const T other) const { return !(*this == other); } public: /// Constructs a C handle wrapper. /// @param t The C handle to wrap. /// @param weak A flag to specify whether to construct a weak wrapper. handle(T t = 0, bool weak = false): _data(0) { reset(t, weak); } handle(const handle &other): _data(other._data) {} handle &operator=(const handle &other) { _data = other._data; return *this; } /// Resets the value of a C handle. /// @param t The new value of the C handle. /// @param weak A flag to specify whether the wrapper should be weak. void reset(T t, bool weak = false) { auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); }; _data.reset(t, weak ? dummy_destructor : traits::destructor); } /// Returns the value of the underlying C handle. T get() const { return _data.get(); } bool operator==(const handle &other) const { return other._data.get() == _data.get(); } bool operator!=(const handle &other) const { return !(*this == other); } }; #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits<mkldnn_memory_t> { static constexpr auto destructor = &mkldnn_memory_destroy; }; template <> struct handle_traits<mkldnn_primitive_desc_t> { static constexpr auto destructor = &mkldnn_primitive_desc_destroy; }; template <> struct handle_traits<mkldnn_primitive_t> { static constexpr auto destructor = &mkldnn_primitive_destroy; }; template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> { static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy; }; #endif struct memory; struct primitive_desc; /// Base class for all computational primitives. class primitive: public handle<mkldnn_primitive_t> { friend struct error; friend struct stream; using handle::handle; public: /// A proxy to C primitive kind enum enum class kind { undefined_primitive = mkldnn_undefined_primitive, reorder = mkldnn_reorder, concat = mkldnn_concat, sum = mkldnn_sum, convolution = mkldnn_convolution, deconvolution = mkldnn_deconvolution, shuffle = mkldnn_shuffle, eltwise = mkldnn_eltwise, softmax = mkldnn_softmax, pooling = mkldnn_pooling, lrn = mkldnn_lrn, batch_normalization = mkldnn_batch_normalization, inner_product = mkldnn_inner_product, rnn = mkldnn_rnn, }; primitive(const_mkldnn_primitive_desc_t c_pd); primitive(const primitive_desc &pd); /// Returns the descriptor of the underlying C API primitive. inline const_mkldnn_primitive_desc_t get_primitive_desc() const; // TODO: use the C++ API wrapper structure. void execute(struct stream &astream, const std::unordered_map<int, memory> &args) const; }; inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { return static_cast<mkldnn_primitive_kind_t>(akind); } /// Intel(R) MKL-DNN exception class. /// /// This class captures the status returned by the failed C API function, error /// message, and, optionally, handle of the primitive that caused the error. struct error: public std::exception { mkldnn_status_t status; const char *message; /// Constructs an error instance. /// /// @param astatus The error status returned by the C API. /// @param amessage The error message. error(mkldnn_status_t astatus, const char *amessage) : status(astatus), message(amessage) {} /// A convenience function for wrapping calls to the C API. Checks the /// return status and throws an #error in case of failure. /// /// @param status The error status returned by the C API. /// @param message The error message. static void wrap_c_api(mkldnn_status_t status, const char *message) { if (status != mkldnn_success) throw error(status, message); } }; const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { const_mkldnn_primitive_desc_t pd; error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), "could not get primitive descriptor by primitive"); return pd; } /// @} /// @addtogroup cpp_api_enums Common data types and enumerations /// A proxy to @ref c_api_types in @ref c_api. /// /// @{ enum scratchpad_mode { scratchpad_mode_library = mkldnn_scratchpad_mode_library, scratchpad_mode_user = mkldnn_scratchpad_mode_user, }; inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { return static_cast<mkldnn_scratchpad_mode_t>(mode); } enum padding_kind { zero = mkldnn_padding_zero }; inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { return static_cast<mkldnn_padding_kind_t>(kind); } enum prop_kind { forward_training = mkldnn_forward_training, forward_scoring = mkldnn_forward_scoring, forward_inference = mkldnn_forward_inference, forward = mkldnn_forward, backward = mkldnn_backward, backward_data = mkldnn_backward_data, backward_weights = mkldnn_backward_weights, backward_bias = mkldnn_backward_bias }; inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { return static_cast<mkldnn_prop_kind_t>(kind); } enum algorithm { algorithm_undef = mkldnn_alg_kind_undef, convolution_auto = mkldnn_convolution_auto, convolution_direct = mkldnn_convolution_direct, convolution_winograd = mkldnn_convolution_winograd, deconvolution_direct = mkldnn_deconvolution_direct, deconvolution_winograd = mkldnn_deconvolution_winograd, eltwise_relu = mkldnn_eltwise_relu, eltwise_tanh = mkldnn_eltwise_tanh, eltwise_elu = mkldnn_eltwise_elu, eltwise_square = mkldnn_eltwise_square, eltwise_abs = mkldnn_eltwise_abs, eltwise_sqrt = mkldnn_eltwise_sqrt, eltwise_linear = mkldnn_eltwise_linear, eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, eltwise_soft_relu = mkldnn_eltwise_soft_relu, eltwise_logistic = mkldnn_eltwise_logistic, lrn_across_channels = mkldnn_lrn_across_channels, lrn_within_channel = mkldnn_lrn_within_channel, pooling_max = mkldnn_pooling_max, pooling_avg = mkldnn_pooling_avg, pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, vanilla_rnn = mkldnn_vanilla_rnn, vanilla_lstm = mkldnn_vanilla_lstm, vanilla_gru = mkldnn_vanilla_gru, gru_linear_before_reset = mkldnn_gru_linear_before_reset }; inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { return static_cast<mkldnn_alg_kind_t>(aalgorithm); } enum batch_normalization_flag { use_global_stats = mkldnn_use_global_stats, use_scale_shift = mkldnn_use_scaleshift, fuse_bn_relu = mkldnn_fuse_bn_relu }; inline mkldnn_batch_normalization_flag_t convert_to_c( batch_normalization_flag aflag) { return static_cast<mkldnn_batch_normalization_flag_t>(aflag); } enum rnn_direction { unidirectional_left2right = mkldnn_unidirectional_left2right, unidirectional_right2left = mkldnn_unidirectional_right2left, unidirectional = mkldnn_unidirectional, bidirectional_concat = mkldnn_bidirectional_concat, bidirectional_sum = mkldnn_bidirectional_sum, }; inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { return static_cast<mkldnn_rnn_direction_t>(adir); } enum query { undef = mkldnn_query_undef, query_engine = mkldnn_query_engine, primitive_kind = mkldnn_query_primitive_kind, num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32, num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, time_estimate_f64 = mkldnn_query_time_estimate_f64, memory_consumption_s64 = mkldnn_query_memory_consumption_s64, query_scratchpad_engine = mkldnn_query_scratchpad_engine, impl_info_str = mkldnn_query_impl_info_str, op_d = mkldnn_query_op_d, convolution_d = mkldnn_query_convolution_d, deconvolution_d = mkldnn_query_deconvolution_d, shuffle_d = mkldnn_query_shuffle_d, eltwise_d = mkldnn_query_eltwise_d, softmax_d = mkldnn_query_softmax_d, pooling_d = mkldnn_query_pooling_d, lrn_d = mkldnn_query_lrn_d, batch_normalization_d = mkldnn_query_batch_normalization_d, inner_product_d = mkldnn_query_inner_product_d, rnn_d = mkldnn_query_rnn_d, src_md = mkldnn_query_src_md, diff_src_md = mkldnn_query_diff_src_md, weights_md = mkldnn_query_weights_md, diff_weights_md = mkldnn_query_diff_weights_md, dst_md = mkldnn_query_dst_md, diff_dst_md = mkldnn_query_diff_dst_md, workspace_md = mkldnn_query_workspace_md, scratchpad_md = mkldnn_query_scratchpad_md, }; inline mkldnn_query_t convert_to_c(query aquery) { return static_cast<mkldnn_query_t>(aquery); } /// @} /// @addtogroup cpp_api_attr Attributes /// An extension for controlling primitive behavior. /// /// @sa @ref c_api_attributes in @ref c_api /// @{ #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits<mkldnn_post_ops_t> { static constexpr auto destructor = &mkldnn_post_ops_destroy; }; #endif struct post_ops: public handle<mkldnn_post_ops_t> { post_ops() { mkldnn_post_ops_t result; error::wrap_c_api(mkldnn_post_ops_create(&result), "could not create post operation sequence"); reset(result); } int len() const { return mkldnn_post_ops_len(get()); } primitive::kind kind(int index) const { error::wrap_c_api( index < len() ? mkldnn_success : mkldnn_invalid_arguments, "post_ops index is out of range"); return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(), index)); } void append_sum(float scale = 1.) { error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), "could not append sum"); } void get_params_sum(int index, float &scale) const { error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), "could not get sum params"); } void append_eltwise(float scale, algorithm alg, float alpha, float beta) { error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale, convert_to_c(alg), alpha, beta), "could not append eltwise"); } void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const { mkldnn_alg_kind_t c_alg; error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index, &scale, &c_alg, &alpha, &beta), "could not get eltwise params"); alg = static_cast<algorithm>(c_alg); } }; #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits<mkldnn_primitive_attr_t> { static constexpr auto destructor = &mkldnn_primitive_attr_destroy; }; #endif struct primitive_attr: public handle<mkldnn_primitive_attr_t> { primitive_attr() { mkldnn_primitive_attr_t result; error::wrap_c_api(mkldnn_primitive_attr_create(&result), "could not create a primitive attr"); reset(result); } scratchpad_mode get_scratchpad_mode() const { mkldnn_scratchpad_mode_t result; error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode( get(), &result), "could not get scratchpad mode"); return scratchpad_mode(result); } void set_scratchpad_mode(scratchpad_mode mode) { error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode( get(), mkldnn::convert_to_c(mode)), "could not set scratchpad mode"); } void get_output_scales(int &mask, std::vector<float> &scales) const { mkldnn_dim_t count; int c_mask; const float *c_scales; error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(), &count, &c_mask, &c_scales), "could not get int output scales"); scales.resize(count); mask = c_mask; for (mkldnn_dim_t c = 0; c < count; ++c) scales[c] = c_scales[c]; } void set_output_scales(int mask, const std::vector<float> &scales) { error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(), (mkldnn_dim_t)scales.size(), mask, &scales[0]), "could not set int output scales"); } const post_ops get_post_ops() const { post_ops result; const_mkldnn_post_ops_t c_result; error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), "could not get post operation sequence"); result.reset(const_cast<mkldnn_post_ops_t>(c_result), true); return result; } void set_post_ops(post_ops ops) { error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), "could not set post operation sequence"); } void set_rnn_data_qparams(const float scale, const float shift) { error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(), scale, shift), "could not set rnn data int scale/shift"); } void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) { error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(), (int)scales.size(), mask, &scales[0]), "could not set rnn weights int scales"); } }; /// @} /// @addtogroup cpp_api_engine Engine /// Engine operations. /// /// @sa @ref c_api_engine in @ref c_api /// @{ #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits<mkldnn_engine_t> { static constexpr auto destructor = &mkldnn_engine_destroy; }; #endif /// An execution engine. struct engine: public handle<mkldnn_engine_t> { friend class primitive; // gcc bug??? using handle::handle; /// Kinds of engines. enum kind { /// An unspecified engine any = mkldnn_any_engine, /// CPU engine cpu = mkldnn_cpu, }; /// Returns the number of engines of a certain kind. /// /// @param akind The kind of engines to count. static size_t get_count(kind akind) { return mkldnn_engine_get_count(convert_to_c(akind)); } /// Constructs an engine. /// /// @param akind The kind of engine to construct. /// @param index The index of the engine. Must be less than the value /// returned by #get_count() for this particular kind of engine. engine(kind akind, size_t index) { mkldnn_engine_t aengine; error::wrap_c_api( mkldnn_engine_create(&aengine, convert_to_c(akind), index), "could not create an engine"); reset(aengine); } explicit engine(const mkldnn_engine_t& aengine) : handle(aengine, true) {} engine(const handle<mkldnn_primitive_desc_t> &pd) { mkldnn_engine_t engine_q; error::wrap_c_api( mkldnn_primitive_desc_query(pd.get(), mkldnn::convert_to_c(query_engine), 0, &engine_q), "could not get engine from primitive_desc"); reset(engine_q, true); } template <class primitive_desc> static engine query(const primitive_desc &pd) { mkldnn_engine_t engine_q; error::wrap_c_api( mkldnn_primitive_desc_query(pd.get(), mkldnn::convert_to_c(query_engine), 0, &engine_q), "could not get engine from primitive_desc"); return engine(engine_q); } private: static mkldnn_engine_kind_t convert_to_c(kind akind) { return static_cast<mkldnn_engine_kind_t>(akind); } }; /// @} /// @addtogroup cpp_api_stream Stream /// Execution stream operations /// /// @sa @ref c_api_stream in @ref c_api /// @{ #ifndef DOXYGEN_SHOULD_SKIP_THIS template <> struct handle_traits<mkldnn_stream_t> { static constexpr auto destructor = &mkldnn_stream_destroy; }; #endif struct stream: public handle<mkldnn_stream_t> { using handle::handle; enum: unsigned { default_flags = mkldnn_stream_default_flags, }; /// Constructs a stream. stream(const engine &aengine, unsigned flags = static_cast<unsigned>(default_flags)) { mkldnn_stream_t astream; error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags), "could not create a stream"); reset(astream); } }; /// @} /// @addtogroup cpp_api_memory_related Memory and memory related operations /// @{ /// @addtogroup cpp_api_memory Memory /// A primitive to describe and store data. /// /// For more information, refer to @ref c_api_memory in @ref c_api. /// @{ /// Memory that describes the data. struct memory: public handle<mkldnn_memory_t> { public: typedef mkldnn_dim_t dim; typedef std::vector<dim> dims; template <typename T> static void validate_dims(const std::vector<T> &v) { if (v.size() > MKLDNN_MAX_NDIMS) throw error(mkldnn_invalid_arguments, "invalid dimensions"); } /// Data type specification. See #mkldnn_data_type_t for a detailed /// description. enum data_type { data_undef = mkldnn_data_type_undef, f32 = mkldnn_f32, s32 = mkldnn_s32, s8 = mkldnn_s8, u8 = mkldnn_u8, }; /// Memory format tag specification. See #mkldnn_format_tag_t /// for a detailed description. enum format_tag { format_tag_undef = mkldnn_format_tag_undef, any = mkldnn_format_tag_any, a = mkldnn_a, ab = mkldnn_ab, abc = mkldnn_abc, abcd = mkldnn_abcd, abcde = mkldnn_abcde, abcdef = mkldnn_abcdef, abdec = mkldnn_abdec, acb = mkldnn_acb, acbde = mkldnn_acbde, acdb = mkldnn_acdb, acdeb = mkldnn_acdeb, ba = mkldnn_ba, bac = mkldnn_bac, bacd = mkldnn_bacd, bcda = mkldnn_bcda, cba = mkldnn_cba, cdba = mkldnn_cdba, cdeba = mkldnn_cdeba, decab = mkldnn_decab, Abc16a = mkldnn_Abc16a, ABc16a16b = mkldnn_ABc16a16b, aBc16b = mkldnn_aBc16b, ABc16b16a = mkldnn_ABc16b16a, Abc4a = mkldnn_Abc4a, aBc4b = mkldnn_aBc4b, ABc4b16a4b = mkldnn_ABc4b16a4b, ABc4b4a = mkldnn_ABc4b4a, ABc8a16b2a = mkldnn_ABc8a16b2a, ABc8a8b = mkldnn_ABc8a8b, aBc8b = mkldnn_aBc8b, ABc8b16a2b = mkldnn_ABc8b16a2b, ABc8b8a = mkldnn_ABc8b8a, Abcd16a = mkldnn_Abcd16a, ABcd16a16b = mkldnn_ABcd16a16b, aBcd16b = mkldnn_aBcd16b, ABcd16b16a = mkldnn_ABcd16b16a, aBCd16b16c = mkldnn_aBCd16b16c, aBCd16c16b = mkldnn_aBCd16c16b, Abcd4a = mkldnn_Abcd4a, aBcd4b = mkldnn_aBcd4b, ABcd4b16a4b = mkldnn_ABcd4b16a4b, ABcd4b4a = mkldnn_ABcd4b4a, aBCd4c16b4c = mkldnn_aBCd4c16b4c, aBCd4c4b = mkldnn_aBCd4c4b, ABcd8a16b2a = mkldnn_ABcd8a16b2a, ABcd8a8b = mkldnn_ABcd8a8b, aBcd8b = mkldnn_aBcd8b, ABcd8b16a2b = mkldnn_ABcd8b16a2b, aBCd8b16c2b = mkldnn_aBCd8b16c2b, ABcd8b8a = mkldnn_ABcd8b8a, aBCd8b8c = mkldnn_aBCd8b8c, aBCd8c16b2c = mkldnn_aBCd8c16b2c, aBCd8c8b = mkldnn_aBCd8c8b, Abcde16a = mkldnn_Abcde16a, ABcde16a16b = mkldnn_ABcde16a16b, aBcde16b = mkldnn_aBcde16b, ABcde16b16a = mkldnn_ABcde16b16a, aBCde16b16c = mkldnn_aBCde16b16c, aBCde16c16b = mkldnn_aBCde16c16b, aBCde2c8b4c = mkldnn_aBCde2c8b4c, Abcde4a = mkldnn_Abcde4a, aBcde4b = mkldnn_aBcde4b, ABcde4b4a = mkldnn_ABcde4b4a, aBCde4b4c = mkldnn_aBCde4b4c, aBCde4c16b4c = mkldnn_aBCde4c16b4c, aBCde4c4b = mkldnn_aBCde4c4b, Abcde8a = mkldnn_Abcde8a, ABcde8a8b = mkldnn_ABcde8a8b, aBcde8b = mkldnn_aBcde8b, ABcde8b16a2b = mkldnn_ABcde8b16a2b, aBCde8b16c2b = mkldnn_aBCde8b16c2b, ABcde8b8a = mkldnn_ABcde8b8a, aBCde8b8c = mkldnn_aBCde8b8c, aBCde8c16b2c = mkldnn_aBCde8c16b2c, aBCde8c8b = mkldnn_aBCde8c8b, aBcdef16b = mkldnn_aBcdef16b, aBCdef16b16c = mkldnn_aBCdef16b16c, aBCdef16c16b = mkldnn_aBCdef16c16b, aBcdef4b = mkldnn_aBcdef4b, aBCdef4c4b = mkldnn_aBCdef4c4b, aBCdef8b8c = mkldnn_aBCdef8b8c, aBCdef8c16b2c = mkldnn_aBCdef8c16b2c, aBCdef8c8b = mkldnn_aBCdef8c8b, aBdc16b = mkldnn_aBdc16b, aBdc4b = mkldnn_aBdc4b, aBdc8b = mkldnn_aBdc8b, aBdec16b = mkldnn_aBdec16b, aBdec4b = mkldnn_aBdec4b, aBdec8b = mkldnn_aBdec8b, aBdefc16b = mkldnn_aBdefc16b, aBdefc4b = mkldnn_aBdefc4b, aBdefc8b = mkldnn_aBdefc8b, Acb16a = mkldnn_Acb16a, Acb4a = mkldnn_Acb4a, Acb8a = mkldnn_Acb8a, aCBd16b16c = mkldnn_aCBd16b16c, aCBde16b16c = mkldnn_aCBde16b16c, Acdb16a = mkldnn_Acdb16a, Acdb4a = mkldnn_Acdb4a, Acdb8a = mkldnn_Acdb8a, Acdeb16a = mkldnn_Acdeb16a, Acdeb4a = mkldnn_Acdeb4a, Acdeb8a = mkldnn_Acdeb8a, BAc16a16b = mkldnn_BAc16a16b, BAcd16a16b = mkldnn_BAcd16a16b, format_tag_last = mkldnn_format_tag_last, x = mkldnn_x, nc = mkldnn_nc, cn = mkldnn_cn, ncw = mkldnn_ncw, nwc = mkldnn_nwc, nchw = mkldnn_nchw, nhwc = mkldnn_nhwc, chwn = mkldnn_chwn, ncdhw = mkldnn_ncdhw, ndhwc = mkldnn_ndhwc, oi = mkldnn_oi, io = mkldnn_io, oiw = mkldnn_oiw, wio = mkldnn_wio, oihw = mkldnn_oihw, hwio = mkldnn_hwio, ihwo = mkldnn_ihwo, iohw = mkldnn_iohw, oidhw = mkldnn_oidhw, dhwio = mkldnn_dhwio, goiw = mkldnn_goiw, goihw = mkldnn_goihw, hwigo = mkldnn_hwigo, giohw = mkldnn_giohw, goidhw = mkldnn_goidhw, tnc = mkldnn_tnc, ntc = mkldnn_ntc, ldsnc = mkldnn_ldsnc, ldigo = mkldnn_ldigo, ldgoi = mkldnn_ldgoi, ldgo = mkldnn_ldgo, nCdhw16c = mkldnn_nCdhw16c, nCdhw4c = mkldnn_nCdhw4c, nCdhw8c = mkldnn_nCdhw8c, nChw16c = mkldnn_nChw16c, nChw4c = mkldnn_nChw4c, nChw8c = mkldnn_nChw8c, nCw16c = mkldnn_nCw16c, nCw4c = mkldnn_nCw4c, nCw8c = mkldnn_nCw8c, IOw16o16i = mkldnn_IOw16o16i, OIw16i16o = mkldnn_OIw16i16o, OIw16o16i = mkldnn_OIw16o16i, Oiw16o = mkldnn_Oiw16o, OIw4i16o4i = mkldnn_OIw4i16o4i, OIw4i4o = mkldnn_OIw4i4o, Oiw4o = mkldnn_Oiw4o, OIw8i16o2i = mkldnn_OIw8i16o2i, OIw8i8o = mkldnn_OIw8i8o, OIw8o16i2o = mkldnn_OIw8o16i2o, OIw8o8i = mkldnn_OIw8o8i, Owi16o = mkldnn_Owi16o, Owi4o = mkldnn_Owi4o, Owi8o = mkldnn_Owi8o, IOhw16o16i = mkldnn_IOhw16o16i, Ohwi16o = mkldnn_Ohwi16o, Ohwi4o = mkldnn_Ohwi4o, Ohwi8o = mkldnn_Ohwi8o, OIhw16i16o = mkldnn_OIhw16i16o, OIhw16o16i = mkldnn_OIhw16o16i, Oihw16o = mkldnn_Oihw16o, OIhw4i16o4i = mkldnn_OIhw4i16o4i, OIhw4i4o = mkldnn_OIhw4i4o, Oihw4o = mkldnn_Oihw4o, OIhw8i16o2i = mkldnn_OIhw8i16o2i, OIhw8i8o = mkldnn_OIhw8i8o, OIhw8o16i2o = mkldnn_OIhw8o16i2o, OIhw8o8i = mkldnn_OIhw8o8i, Odhwi16o = mkldnn_Odhwi16o, Odhwi4o = mkldnn_Odhwi4o, Odhwi8o = mkldnn_Odhwi8o, OIdhw16i16o = mkldnn_OIdhw16i16o, OIdhw16o16i = mkldnn_OIdhw16o16i, Oidhw16o = mkldnn_Oidhw16o, OIdhw4i4o = mkldnn_OIdhw4i4o, Oidhw4o = mkldnn_Oidhw4o, OIdhw8i16o2i = mkldnn_OIdhw8i16o2i, OIdhw8i8o = mkldnn_OIdhw8i8o, OIdhw8o8i = mkldnn_OIdhw8o8i, gIOw16o16i = mkldnn_gIOw16o16i, gOIw16i16o = mkldnn_gOIw16i16o, gOIw16o16i = mkldnn_gOIw16o16i, gOiw16o = mkldnn_gOiw16o, gOIw4i16o4i = mkldnn_gOIw4i16o4i, gOIw4i4o = mkldnn_gOIw4i4o, gOiw4o = mkldnn_gOiw4o, gOIw8i16o2i = mkldnn_gOIw8i16o2i, gOIw8i8o = mkldnn_gOIw8i8o, gOIw8o16i2o = mkldnn_gOIw8o16i2o, gOIw8o8i = mkldnn_gOIw8o8i, gOwi16o = mkldnn_gOwi16o, gOwi4o = mkldnn_gOwi4o, gOwi8o = mkldnn_gOwi8o, gIOhw16o16i = mkldnn_gIOhw16o16i, gOhwi16o = mkldnn_gOhwi16o, gOhwi4o = mkldnn_gOhwi4o, gOhwi8o = mkldnn_gOhwi8o, Goihw16g = mkldnn_Goihw16g, gOIhw16i16o = mkldnn_gOIhw16i16o, gOIhw16o16i = mkldnn_gOIhw16o16i, gOihw16o = mkldnn_gOihw16o, gOIhw2i8o4i = mkldnn_gOIhw2i8o4i, gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, gOIhw4i4o = mkldnn_gOIhw4i4o, gOIhw4o4i = mkldnn_gOIhw4o4i, gOihw4o = mkldnn_gOihw4o, Goihw8g = mkldnn_Goihw8g, gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, gOIhw8i8o = mkldnn_gOIhw8i8o, gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, gOIhw8o8i = mkldnn_gOIhw8o8i, gOdhwi16o = mkldnn_gOdhwi16o, gOdhwi4o = mkldnn_gOdhwi4o, gOdhwi8o = mkldnn_gOdhwi8o, gOIdhw16i16o = mkldnn_gOIdhw16i16o, gOIdhw16o16i = mkldnn_gOIdhw16o16i, gOidhw16o = mkldnn_gOidhw16o, gOIdhw4i4o = mkldnn_gOIdhw4i4o, gOidhw4o = mkldnn_gOidhw4o, gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i, gOIdhw8i8o = mkldnn_gOIdhw8i8o, gOIdhw8o8i = mkldnn_gOIdhw8o8i, }; /// A memory descriptor. struct desc { friend struct memory; /// The underlying C API data structure. mkldnn_memory_desc_t data; /// Constructs a zero memory descriptor desc(): data() {} /// Constructs a memory descriptor. /// /// @param adims Data dimensions /// @param adata_type Data precision/type. /// @param aformat Data layout format tag. desc(const dims &adims, data_type adata_type, format_tag aformat) { validate_dims(adims); error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(), adims.size() == 0 ? nullptr : &adims[0], convert_to_c(adata_type), convert_to_c(aformat)), "could not initialize a memory descriptor"); } /// Constructs a memory descriptor from a C API data structure. /// /// @param adata A C API #mkldnn_memory_desc_t structure. desc(const mkldnn_memory_desc_t &adata): data(adata) {} /// Constructs a sub-memory descriptor // /// @param adims Sizes of a sub-memory /// @param offsets Offsets of a sub-memory desc submemory_desc(const dims &adims, const dims &offsets) { mkldnn_memory_desc_t sub_md; error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md, &data, &adims[0], &offsets[0]), "could not initialize a sub-memory"); return desc(sub_md); } /// Returns the number of bytes required to allocate the memory described /// including the padding area. size_t get_size() const { return mkldnn_memory_desc_get_size(&data); } bool operator==(const desc &other) const { return mkldnn_memory_desc_equal(&data, &other.data) != 0; } bool operator!=(const desc &other) const { return !operator==(other); } }; /// Constructs a memory. /// /// @param md Memory descriptor. /// @param aengine Engine. /// @param ahandle Native handle. memory(const desc &md, const engine &aengine, void *ahandle) { mkldnn_memory_t result; error::wrap_c_api(mkldnn_memory_create(&result, &md.data, aengine.get(), ahandle), "could not create a memory"); reset(result); } /// Constructs a memory. /// /// @param md Memory descriptor. /// @param aengine Engine. memory(const desc &md, const engine &aengine) : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {} /// Returns the descriptor of the memory. desc get_desc() const { const mkldnn_memory_desc_t *cdesc; error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc), "could not get memory descriptor from a memory"); return desc(*cdesc); } /// Returns the engine of the memory. engine get_engine() const { mkldnn_engine_t engine_q; error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q), "could not get engine from a memory"); return engine(engine_q); } /// Returns a handle of the data contained in the memory. /// /// On the CPU engine, this is a pointer to the allocated memory. void *get_data_handle() const { void *handle; error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle), "could not get native handle"); return handle; } void set_data_handle(void *handle) const { error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle), "could not set native handle"); } // Must go away or be private: static mkldnn_data_type_t convert_to_c(data_type adata_type) { return static_cast<mkldnn_data_type_t>(adata_type); } static mkldnn_format_tag_t convert_to_c(format_tag aformat) { return static_cast<mkldnn_format_tag_t>(aformat); } }; inline bool operator==(mkldnn_data_type_t a, memory::data_type b) { return a == memory::convert_to_c(b); } inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) { return !(a == b); } inline bool operator==(memory::data_type a, mkldnn_data_type_t b) { return b == a; } inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) { return !(a == b); } inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) { return a == memory::convert_to_c(b); } inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) { return !(a == b); } inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) { return b == a; } inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) { return !(a == b); } /// @} /// @addtogroup cpp_api_reorder Reorder /// A primitive to copy data between memory formats. /// /// @sa @ref c_api_reorder in @ref c_api /// @{ struct reorder : public primitive { struct primitive_desc : public handle<mkldnn_primitive_desc_t> { primitive_desc(const engine &src_engine, const memory::desc &src_md, const engine &dst_engine, const memory::desc &dst_md, const primitive_attr &aattr) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, src_engine.get(), &src_md.data, dst_engine.get(), &dst_md.data, aattr.get()), "could not create a reorder primitive descriptor"); reset(result); } primitive_desc(const engine &src_engine, const memory::desc &src_md, const engine &dst_engine, const memory::desc &dst_md) { mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, src_engine.get(), &src_md.data, dst_engine.get(), &dst_md.data, nullptr), "could not create a reorder primitive descriptor"); reset(result); } primitive_desc(const memory &src, const memory &dst, const primitive_attr &aattr) { mkldnn_primitive_desc_t result; auto src_md = src.get_desc(); auto dst_md = dst.get_desc(); error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, src.get_engine().get(), &src_md.data, dst.get_engine().get(), &dst_md.data, aattr.get()), "could not create a reorder primitive descriptor"); reset(result); } primitive_desc(const memory &src, const memory &dst) { mkldnn_primitive_desc_t result; auto src_md = src.get_desc(); auto dst_md = dst.get_desc(); error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, src.get_engine().get(), &src_md.data, dst.get_engine().get(), &dst_md.data, nullptr), "could not create a reorder primitive descriptor"); reset(result); } memory::desc scratchpad_desc() const { const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( get(), mkldnn::convert_to_c(scratchpad_md), 0); if (cdesc == nullptr) return memory::desc(); return memory::desc(*cdesc); } engine scratchpad_engine() { mkldnn_engine_t engine_q; error::wrap_c_api( mkldnn_primitive_desc_query(get(), mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q), "could not get scratchpad engine from reorder primitive_desc"); return engine(engine_q); } engine get_engine() { return engine::query(*this); } }; reorder(const primitive_desc &pd): primitive(pd.get()) {} reorder(const memory &src, const memory &dst): primitive(primitive_desc(src, dst).get()) {} void execute(stream astream, memory &src, memory &dst) { primitive::execute(astream, {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}}); } }; /// @} /// @addtogroup cpp_api_concat Concat /// A primitive to concatenate data by arbitrary dimension. /// /// @sa @ref c_api_concat in @ref c_api /// @{ struct concat : public primitive { struct primitive_desc : public handle<mkldnn_primitive_desc_t> { std::vector<mkldnn_memory_desc_t> cpp_to_c( const std::vector<memory::desc> &srcs) { std::vector<mkldnn_memory_desc_t> c_api_srcs; c_api_srcs.reserve(srcs.size()); for (const auto &s : srcs) c_api_srcs.push_back(s.data); return c_api_srcs; } primitive_desc(const memory::desc &dst, int concat_dimension, const std::vector<memory::desc> &srcs, const engine &aengine) { auto c_api_srcs = cpp_to_c(srcs); mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_concat_primitive_desc_create( &result, &dst.data, (int)c_api_srcs.size(), concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), "could not create a concat primitive descriptor"); reset(result); } primitive_desc(int concat_dimension, const std::vector<memory::desc> &srcs, const engine &aengine) { auto c_api_srcs = cpp_to_c(srcs); mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_concat_primitive_desc_create( &result, nullptr, (int)c_api_srcs.size(), concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), "could not create a concat primitive descriptor"); reset(result); } memory::desc dst_desc() const { const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( get(), mkldnn::convert_to_c(dst_md), 0); error::wrap_c_api( cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, "could not get a dst memory descriptor"); return memory::desc(*cdesc); } memory::desc scratchpad_desc() const { const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( get(), mkldnn::convert_to_c(scratchpad_md), 0); if (cdesc == nullptr) return memory::desc(); return memory::desc(*cdesc); } engine get_engine() { return engine::query(*this); } }; concat(const primitive_desc &pd): primitive(pd.get()) {} }; /// @} /// @addtogroup cpp_api_sum Sum /// A primitive to sum data. /// /// @sa @ref c_api_sum in @ref c_api /// @{ struct sum : public primitive { struct primitive_desc : public handle<mkldnn_primitive_desc_t> { std::vector<mkldnn_memory_desc_t> cpp_to_c( const std::vector<memory::desc> &srcs) { std::vector<mkldnn_memory_desc_t> c_api_srcs; c_api_srcs.reserve(srcs.size()); for (const auto &s : srcs) c_api_srcs.push_back(s.data); return c_api_srcs; } primitive_desc(const memory::desc &dst, const std::vector<float> &scales, const std::vector<memory::desc> &srcs, const engine &aengine) { error::wrap_c_api(scales.size() == srcs.size() ? mkldnn_success : mkldnn_invalid_arguments, "number of scales not equal to number of srcs"); auto c_api_srcs = cpp_to_c(srcs); mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_sum_primitive_desc_create( &result, &dst.data, (int)c_api_srcs.size(), &scales[0], &c_api_srcs[0], nullptr, aengine.get()), "could not create a sum primitive descriptor"); reset(result); } primitive_desc(const std::vector<float> &scales, const std::vector<memory::desc> &srcs, const engine &aengine) { error::wrap_c_api(scales.size() == srcs.size() ? mkldnn_success : mkldnn_invalid_arguments, "number of scales not equal to number of srcs"); auto c_api_srcs = cpp_to_c(srcs); mkldnn_primitive_desc_t result; error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result, nullptr, (int)c_api_srcs.size(), &scales[0], &c_api_srcs[0], nullptr, aengine.get()), "could not create a sum primitive descriptor"); reset(result); } memory::desc dst_desc() const { const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( get(), mkldnn::convert_to_c(dst_md), 0); error::wrap_c_api( cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, "could not get a dst memory descriptor"); return memory::desc(*cdesc); } memory::desc scratchpad_desc() const { const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( get(), mkldnn::convert_to_c(scratchpad_md), 0); if (cdesc == nullptr) return memory::desc(); return memory::desc(*cdesc); } engine get_engine() { return engine::query(*this); } }; sum(const primitive_desc &pd): primitive(pd.get()) {} }; /// @} /// @} /// @addtogroup cpp_api_primitives Primitives /// @{ /// @addtogroup cpp_api_primitive_descriptors Primitive descriptors /// @{ /// A base class for all primitive descriptors. struct primitive_desc : public handle<mkldnn_primitive_desc_t> { primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) { mkldnn_primitive_desc_iterator_t iterator = nullptr; mkldnn_status_t status = mkldnn_primitive_desc_iterator_create( &iterator, desc, attr ? attr->get() : nullptr, e.get(), hint_fwd_pd); error::wrap_c_api(status, "could not create a primitive descriptor iterator"); pd_iterator.reset(iterator); fetch_impl(); } engine get_engine() { return engine::query(*this); } primitive_attr get_primitive_attr() const { const_mkldnn_primitive_attr_t const_cattr; error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr), "could not get attributes"); mkldnn_primitive_attr_t cattr; error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr), "could not clone attributes"); primitive_attr attr; attr.reset(cattr); return attr; } /// Returns implementation name const char *impl_info_str() const { const char *res; error::wrap_c_api(mkldnn_primitive_desc_query(get(), mkldnn_query_impl_info_str, 0, &res), "could not query implementation info string"); return res; } /// Queries the memory::dim value (same as int64_t) memory::dim query_s64(query q) const { memory::dim res; mkldnn_status_t status = mkldnn_primitive_desc_query(get(), mkldnn::convert_to_c(q), 0, &res); return status == mkldnn_success ? res : 0; } /// Advances the next implementation for the given op descriptor. /// /// Returns: /// - @c true on success /// - @c false if the last implementation reached, and /// the primitive descriptor itself is kept unchanged bool next_impl() { mkldnn_status_t status = mkldnn_primitive_desc_iterator_next( pd_iterator.get()); if (status == mkldnn_iterator_ends) return false; error::wrap_c_api(status, "primitive descriptor iterator next failed"); fetch_impl(); return true; } /// Queries and returns requested memory descriptor. memory::desc query_md(query what, int idx = 0) const { std::vector<query> valid_q{src_md, diff_src_md, weights_md, diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md}; if (!std::any_of(valid_q.cbegin(), valid_q.cend(), [=](query q) { return what == q; })) throw error(mkldnn_invalid_arguments, "invalid memory query"); const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( get(), mkldnn::convert_to_c(what), idx); if (cdesc == nullptr) return memory::desc(); return memory::desc(*cdesc); } // register specialized queries, e.g. src_desc() # define REG_QUERY_MD(name, what, idx) \ memory::desc name ## _desc() const { return query_md(what ## _md, idx); } private: handle<mkldnn_primitive_desc_iterator_t> pd_iterator; void fetch_impl() { mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch( pd_iterator.get()); error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error, "could not fetch a primitive descriptor from the iterator"); reset(pd); } }; /// @} /// @addtogroup cpp_api_convolution Convolution /// A primitive to compute convolution using different algorithms. /// /// @sa @ref c_api_convolution in @ref c_api /// @{ struct convolution_forward: public primitive { struct desc { mkldnn_convolution_desc_t data; desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_dilated_convolution_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated convolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_dilated_convolution_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated convolution forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(weights, weights, 0); REG_QUERY_MD(bias, weights, 1); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; convolution_forward(const primitive_desc &pd): primitive(pd) {} }; struct convolution_backward_data : public primitive { struct desc { mkldnn_convolution_desc_t data; desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_convolution_backward_data_desc_init( &data, convert_to_c(aalgorithm), &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward data descriptor"); } desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api( mkldnn_dilated_convolution_backward_data_desc_init( &data, convert_to_c(aalgorithm), &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward data descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(weights, weights, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; convolution_backward_data(const primitive_desc &pd): primitive(pd) {} }; struct convolution_backward_weights : public primitive { struct desc { mkldnn_convolution_desc_t data; desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a convolution backward weights descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(diff_weights, diff_weights, 0); REG_QUERY_MD(diff_bias, diff_weights, 1); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; convolution_backward_weights(const primitive_desc &pd): primitive(pd) {} }; /// @} // /// @addtogroup cpp_api_deconvolution Deconvolution /// A primitive to compute deconvolution using different algorithms. /// /// @sa @ref c_api_deconvolution in @ref c_api /// @{ struct deconvolution_forward: public primitive { struct desc { mkldnn_deconvolution_desc_t data; desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated deconvolution forward descriptor"); } desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated deconvolution forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(weights, weights, 0); REG_QUERY_MD(bias, weights, 1); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; deconvolution_forward(const primitive_desc &pd): primitive(pd) {} }; struct deconvolution_backward_data : public primitive { struct desc { mkldnn_deconvolution_desc_t data; desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init( &data, convert_to_c(aalgorithm), &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution backward data descriptor"); } desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init( &data, convert_to_c(aalgorithm), &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated deconvolution backward data descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(weights, weights, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {} }; struct deconvolution_backward_weights : public primitive { struct desc { mkldnn_deconvolution_desc_t data; desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a deconvolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated deconvolution backward weights descriptor"); } desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(dilates); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( &data, convert_to_c(aalgorithm), &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data, &strides[0], &dilates[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not create a dilated deconvolution backward weights descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(diff_weights, diff_weights, 0); REG_QUERY_MD(diff_bias, diff_weights, 1); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @addtogroup cpp_api_lrn LRN /// A primitive to perform local response normalization (LRN) across or within /// channels. /// /// @sa @ref c_api_lrn in @ref c_api /// @{ struct lrn_forward : public primitive { struct desc { mkldnn_lrn_desc_t data; desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, memory::dim local_size, float alpha, float beta, float k = 1.f) { error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, local_size, alpha, beta, k), "could not create a lrn forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; lrn_forward(const primitive_desc &pd): primitive(pd) {} }; struct lrn_backward : public primitive { struct desc { mkldnn_lrn_desc_t data; desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k = 1.f) { error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, convert_to_c(aalgorithm), &diff_data_desc.data, &data_desc.data, local_size, alpha, beta, k), "could not create a lrn backward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; lrn_backward(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @addtogroup cpp_api_pooling Pooling /// A primitive to perform max or average pooling. /// /// @sa @ref c_api_pooling in @ref c_api /// @{ struct pooling_forward : public primitive { struct desc { mkldnn_pooling_desc_t data; desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(kernel); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), &src_desc.data, &dst_desc.data, &strides[0], &kernel[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not init a forward pooling descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; pooling_forward(const primitive_desc &pd): primitive(pd) {} }; struct pooling_backward : public primitive { struct desc { mkldnn_pooling_desc_t data; desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind) { memory::validate_dims(strides); memory::validate_dims(kernel); memory::validate_dims(padding_l); memory::validate_dims(padding_r); error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data, convert_to_c(aalgorithm), &diff_src_desc.data, &diff_dst_desc.data, &strides[0], &kernel[0], &padding_l[0], &padding_r[0], mkldnn::convert_to_c(apadding_kind)), "could not init a backward pooling descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; pooling_backward(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @addtogroup cpp_api_eltwise Eltwise /// A primitive to compute element-wise operations like parametric rectifier /// linear unit (ReLU). /// /// @sa @ref c_api_eltwise in @ref c_api /// @{ struct eltwise_forward : public primitive { struct desc { mkldnn_eltwise_desc_t data; template <typename T> desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha = 0, T beta = 0) { error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), mkldnn::convert_to_c(alg_kind), &src_desc.data, static_cast<float>(alpha), static_cast<float>(beta)), "could not create a eltwise forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; eltwise_forward(const primitive_desc &pd): primitive(pd) {} }; struct eltwise_backward : public primitive { struct desc { mkldnn_eltwise_desc_t data; template <typename T> desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha = 0, T beta = 0) { error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data, mkldnn::convert_to_c(alg_kind), &diff_data_desc.data, &data_desc.data, static_cast<float>(alpha), static_cast<float>(beta)), "could not create a eltwise backward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; eltwise_backward(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @addtogroup cpp_api_softmax Softmax /// A primitive to perform softmax. /// /// @sa @ref c_api_softmax in @ref c_api /// @{ struct softmax_forward : public primitive { struct desc { mkldnn_softmax_desc_t data; desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis) { error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), &data_desc.data, softmax_axis), "could not create a softmax forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; softmax_forward(const primitive_desc &pd): primitive(pd) {} }; struct softmax_backward : public primitive { struct desc { mkldnn_softmax_desc_t data; desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis) { error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data, &diff_desc.data, &data_desc.data, softmax_axis), "could not init a backward softmax descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; softmax_backward(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @addtogroup cpp_api_batch_norm Batch normalization /// A primitive to perform batch normalization. /// /// @sa @ref c_api_batch_normalization in @ref c_api /// @{ struct batch_normalization_forward : public primitive { struct desc { mkldnn_batch_normalization_desc_t data; template <typename T> desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags) { error::wrap_c_api( mkldnn_batch_normalization_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), &src_desc.data, static_cast<float>(epsilon), flags), "could not create a batch normalization forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(weights, weights, 0); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); memory::desc mean_desc() const { return stat_desc(mean); } memory::desc variance_desc() const { return stat_desc(var); } private: enum { mean = 1, var = 2, }; memory::desc stat_desc(int kind) const { mkldnn_batch_normalization_desc_t *p; error::wrap_c_api(mkldnn_primitive_desc_query( get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), "could not get a batch-normalization descriptor"); return query_md(p->flags & use_global_stats ? src_md : dst_md, kind); } }; batch_normalization_forward(const primitive_desc &pd): primitive(pd) {} }; struct batch_normalization_backward : public primitive { struct desc { mkldnn_batch_normalization_desc_t data; template <typename T> desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags) { error::wrap_c_api( mkldnn_batch_normalization_backward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), &diff_data_desc.data, &data_desc.data, static_cast<float>(epsilon), flags), "could not create a batch normalization backward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(mean, src, 1); REG_QUERY_MD(variance, src, 2); REG_QUERY_MD(weights, weights, 0); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(diff_weights, diff_weights, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; batch_normalization_backward(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @addtogroup cpp_api_inner_product Inner Product /// A primitive to compute an inner product. /// /// @sa @ref c_api_inner_product in @ref c_api /// @{ struct inner_product_forward: public primitive { struct desc { mkldnn_inner_product_desc_t data; desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc) { error::wrap_c_api( mkldnn_inner_product_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), &src_desc.data, &weights_desc.data, &bias_desc.data, &dst_desc.data), "could not create a inner product forward descriptor"); } desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc) { error::wrap_c_api( mkldnn_inner_product_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), &src_desc.data, &weights_desc.data, nullptr, &dst_desc.data), "could not create a inner product forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(weights, weights, 0); REG_QUERY_MD(bias, weights, 1); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; inner_product_forward(const primitive_desc &pd): primitive(pd) {} }; struct inner_product_backward_data: public primitive { struct desc { mkldnn_inner_product_desc_t data; desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc) { error::wrap_c_api( mkldnn_inner_product_backward_data_desc_init(&data, &diff_src_desc.data, &weights_desc.data, &diff_dst_desc.data), "could not create a inner product backward data descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(weights, weights, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; inner_product_backward_data(const primitive_desc &pd): primitive(pd) {} }; struct inner_product_backward_weights: public primitive { struct desc { mkldnn_inner_product_desc_t data; desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc) { error::wrap_c_api( mkldnn_inner_product_backward_weights_desc_init( &data, &src_desc.data, &diff_weights_desc.data, &diff_bias_desc.data, &diff_dst_desc.data), "could not create a inner product backward weights descriptor"); } desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc) { error::wrap_c_api( mkldnn_inner_product_backward_weights_desc_init( &data, &src_desc.data, &diff_weights_desc.data, nullptr, &diff_dst_desc.data), "could not create a inner product backward weights descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(diff_weights, diff_weights, 0); REG_QUERY_MD(diff_bias, diff_weights, 1); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @addtogroup cpp_api_rnn RNN /// A primitive to compute common recurrent layer. /// /// @sa @ref c_api_rnn in @ref c_api /// @{ struct rnn_cell { struct desc { mkldnn_rnn_cell_desc_t c_rnn_cell_; desc(algorithm kind, algorithm activation_f) { error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_, mkldnn::convert_to_c(kind), mkldnn::convert_to_c(activation_f), 0U, 0, 0), "could not init an rnn cell descriptor"); } desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {} operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; } algorithm get_cell_kind() const { return algorithm(c_rnn_cell_.cell_kind); } algorithm get_activation() const { return algorithm(c_rnn_cell_.activation_kind); } float get_alpha() const { return c_rnn_cell_.alpha; } void set_alpha(float alpha) { c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; c_rnn_cell_.alpha = alpha; } float get_clipping() const { return c_rnn_cell_.clipping; } void set_clipping(float clipping) { c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; c_rnn_cell_.clipping = clipping; } int get_gates_count() const { return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); } int get_state_count() const { return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); } }; }; struct rnn_forward : public primitive { struct desc { mkldnn_rnn_desc_t data; desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc ) { error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), cell, mkldnn::convert_to_c(direction), &src_layer_desc.data, &src_iter_desc.data, &weights_layer_desc.data, &weights_iter_desc.data, &bias_desc.data, &dst_layer_desc.data, &dst_iter_desc.data), "could not create an RNN forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} REG_QUERY_MD(src_layer, src, 0); REG_QUERY_MD(src_iter, src, 1); REG_QUERY_MD(weights_layer, weights, 0); REG_QUERY_MD(weights_iter, weights, 1); REG_QUERY_MD(bias, weights, 2); REG_QUERY_MD(dst_layer, dst, 0); REG_QUERY_MD(dst_iter, dst, 1); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; rnn_forward(const primitive_desc &pd): primitive(pd) {} }; struct rnn_backward : public primitive { struct desc { mkldnn_rnn_desc_t data; desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc) { error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), cell, mkldnn::convert_to_c(direction), &src_layer_desc.data, &src_iter_desc.data, &weights_layer_desc.data, &weights_iter_desc.data, &bias_desc.data, &dst_layer_desc.data, &dst_iter_desc.data, &diff_src_layer_desc.data, &diff_src_iter_desc.data, &diff_weights_layer_desc.data, &diff_weights_iter_desc.data, &diff_bias_desc.data, &diff_dst_layer_desc.data, &diff_dst_iter_desc.data), "could not create an RNN backward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(src_layer, src, 0); REG_QUERY_MD(src_iter, src, 1); REG_QUERY_MD(weights_layer, weights, 0); REG_QUERY_MD(weights_iter, weights, 1); REG_QUERY_MD(bias, weights, 2); REG_QUERY_MD(dst_layer, dst, 0); REG_QUERY_MD(dst_iter, dst, 1); REG_QUERY_MD(workspace, workspace, 0); REG_QUERY_MD(diff_src_layer, diff_src, 0); REG_QUERY_MD(diff_src_iter, diff_src, 1); REG_QUERY_MD(diff_weights_layer, diff_weights, 0); REG_QUERY_MD(diff_weights_iter, diff_weights, 1); REG_QUERY_MD(diff_bias, diff_weights, 2); REG_QUERY_MD(diff_dst_layer, diff_dst, 0); REG_QUERY_MD(diff_dst_iter, diff_dst, 1); REG_QUERY_MD(scratchpad, scratchpad, 0); }; // With last iteration (with and without input src_iter) rnn_backward(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @addtogroup cpp_api_shuffle Shuffle /// A primitive to shuffle data along the axis. /// /// @sa @ref c_api_shuffle in @ref c_api /// @{ struct shuffle_forward : public primitive { struct desc { mkldnn_shuffle_desc_t data; desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size) { error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data, mkldnn::convert_to_c(aprop_kind), &data_desc.data, axis, group_size), "could not create a shuffle forward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e) : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} REG_QUERY_MD(src, src, 0); REG_QUERY_MD(dst, dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; shuffle_forward(const primitive_desc &pd): primitive(pd) {} }; struct shuffle_backward : public primitive { struct desc { mkldnn_shuffle_desc_t data; desc(const memory::desc &diff_data_desc, int axis, int group_size) { error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data, &diff_data_desc.data, axis, group_size), "could not create a shuffle backward descriptor"); } }; struct primitive_desc : public mkldnn::primitive_desc { primitive_desc(const desc &desc, const engine &e, const shuffle_forward::primitive_desc &hint_fwd_pd) : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} REG_QUERY_MD(diff_src, diff_src, 0); REG_QUERY_MD(diff_dst, diff_dst, 0); REG_QUERY_MD(scratchpad, scratchpad, 0); }; shuffle_backward(const primitive_desc &pd): primitive(pd) {} }; /// @} /// @} Primitives /// @} C++ API #undef REG_QUERY_MD // implementation section #ifndef DOXYGEN_SHOULD_SKIP_THIS inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) { mkldnn_primitive_t result; error::wrap_c_api(mkldnn_primitive_create(&result, c_pd), "could not create a primitive"); reset(result); } inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {} inline void primitive::execute(stream &astream, const std::unordered_map<int, memory> &args) const { std::vector<mkldnn_exec_arg_t> c_args; c_args.reserve(args.size()); for (const auto &a: args) c_args.push_back({a.first, a.second.get()}); error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(), (int)c_args.size(), c_args.data()), "primitive execution fail"); } #endif // DOXYGEN_SHOULD_SKIP_THIS } // namespace mkldnn #endif