summaryrefslogtreecommitdiff
path: root/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp')
-rw-r--r--thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp183
1 files changed, 183 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
new file mode 100644
index 0000000000..e2130c7ab1
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
@@ -0,0 +1,183 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_ATTR_HPP
+#define PRIMITIVE_ATTR_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct rnn_data_qparams_t : public c_compatible {
+ rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
+ bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
+
+ status_t set(float scale, float shift) {
+ scale_ = scale;
+ shift_ = shift;
+ return status::success;
+ }
+
+ float scale_;
+ float shift_;
+};
+
+struct scales_t: public c_compatible {
+ scales_t(): count_(1), mask_(0), scales_(scales_buf_)
+ { set(1.); }
+
+ scales_t(const scales_t &rhs): scales_t()
+ { set(rhs.count_, rhs.mask_, rhs.scales_); }
+
+ ~scales_t() { cleanup(); }
+
+ scales_t &operator=(const scales_t &rhs) {
+ if (&rhs == this)
+ return *this;
+ status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
+ assert(status == status::success);
+ (void)status;
+ return *this;
+ }
+
+ bool has_default_values() const {
+ for (dim_t c = 0; c < count_; ++c) {
+ if(scales_[c] != 1.) return false;
+ }
+ return true;
+ }
+
+ status_t set(dim_t count, int mask, const float *scales);
+ status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
+
+ dim_t count_;
+ int mask_;
+ float *scales_;
+
+private:
+ enum { scales_buf_size = 16 };
+ float scales_buf_[scales_buf_size];
+
+ void cleanup() {
+ if (scales_ != scales_buf_ && scales_ != nullptr)
+ impl::free(scales_);
+
+ count_ = 1;
+ mask_ = 0;
+ scales_ = scales_buf_;
+ }
+};
+
+}
+}
+
+struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
+ struct entry_t {
+ struct eltwise_t {
+ mkldnn::impl::alg_kind_t alg;
+ float scale, alpha, beta;
+ };
+
+ mkldnn::impl::primitive_kind_t kind;
+ union {
+ struct { float scale; } sum;
+ eltwise_t eltwise;
+ };
+
+ bool is_eltwise(bool require_scale_one = true) const {
+ using namespace mkldnn::impl;
+ return kind == primitive_kind::eltwise
+ && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
+ }
+
+ bool is_relu(bool require_scale_one = true,
+ bool require_nslope_zero = true) const {
+ using namespace mkldnn::impl;
+ return is_eltwise(require_scale_one)
+ && eltwise.alg == alg_kind::eltwise_relu
+ && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
+ }
+
+ bool is_sum(bool require_scale_one = true) const {
+ using namespace mkldnn::impl;
+ return kind == primitive_kind::sum
+ && IMPLICATION(require_scale_one, sum.scale == 1.f);
+ }
+ };
+
+ mkldnn_post_ops(): len_(0) {}
+
+ mkldnn::impl::status_t append_sum(float scale);
+ mkldnn::impl::status_t append_eltwise(float scale,
+ mkldnn::impl::alg_kind_t alg, float alpha, float beta);
+
+ int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
+ int stop = -1) const {
+ if (stop == -1) stop = len_;
+ stop = mkldnn::impl::nstl::min(stop, len_);
+ for (int idx = start; idx < stop; ++idx)
+ if (entry_[idx].kind == kind) return idx;
+ return -1;
+ }
+
+ bool has_default_values() const { return len_ == 0; }
+
+ bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
+ { return find(kind, index, index + 1) == index; }
+
+ enum { capacity = 4 };
+
+ int len_;
+ entry_t entry_[capacity];
+};
+
+struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
+ mkldnn_primitive_attr()
+ : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
+ {}
+
+ mkldnn_primitive_attr *clone() const
+ { return new mkldnn_primitive_attr(*this); }
+
+ /** Returns true if the attributes have default values.
+ *
+ * @note The scratchpad_mode_ is not take into account */
+ bool has_default_values() const {
+ return true
+ && output_scales_.has_default_values()
+ && post_ops_.has_default_values()
+ && rnn_data_qparams_.has_default_values()
+ && rnn_weights_qparams_.has_default_values();
+ }
+
+ mkldnn::impl::status_t set_scratchpad_mode(
+ mkldnn::impl::scratchpad_mode_t scratchpad_mode);
+ mkldnn::impl::status_t set_post_ops(
+ const mkldnn::impl::post_ops_t &post_ops);
+
+ mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
+ mkldnn::impl::scales_t output_scales_;
+ mkldnn::impl::post_ops_t post_ops_;
+ mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
+ mkldnn::impl::scales_t rnn_weights_qparams_;
+};
+
+#endif