diff options
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp | 290 |
1 files changed, 290 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp new file mode 100644 index 0000000000..9fd638842c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +namespace mkldnn { +namespace impl { + +status_t scales_t::set(dim_t count, int mask, const float *scales) { + cleanup(); + + count_ = count; + mask_ = mask; + + if (count_ == 1) { + scales_ = scales_buf_; + utils::array_set(scales_, scales[0], scales_buf_size); + } else { + scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64); + if (scales_ == nullptr) + return status::out_of_memory; + + for (dim_t c = 0; c < count_; ++c) + scales_[c] = scales[c]; + } + + return status::success; +} + +} +} + +status_t post_ops_t::append_sum(float scale) { + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::sum; + entry_[len_].sum.scale = scale; + + len_++; + + return success; +} + +status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha, + float beta) { + using namespace mkldnn::impl::alg_kind; + bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic); + if (!known_alg) + return invalid_arguments; + + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::eltwise; + entry_[len_].eltwise.scale = scale; + entry_[len_].eltwise.alg = alg; + entry_[len_].eltwise.alpha = alpha; + entry_[len_].eltwise.beta = beta; + + len_++; + + return success; +} + +status_t primitive_attr_t::set_scratchpad_mode( + scratchpad_mode_t scratchpad_mode) { + using namespace mkldnn::impl::scratchpad_mode; + + const bool ok = one_of(scratchpad_mode, library, user); + if (!ok) + return invalid_arguments; + + scratchpad_mode_ = scratchpad_mode; + return success; +} + +status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) { + this->post_ops_ = post_ops; + return success; +} + +/* Public C API */ + +status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) { + if (attr == nullptr) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_primitive_attr>(*attr, + new mkldnn_primitive_attr); +} + +status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr, + const primitive_attr_t *existing_attr) { + if (any_null(attr, existing_attr)) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_primitive_attr>(*attr, + existing_attr->clone()); +} + +status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) { + if (attr) + delete attr; + + return success; +} + +status_t mkldnn_primitive_attr_get_scratchpad_mode( + const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) { + if (any_null(attr, scratchpad_mode)) + return invalid_arguments; + + *scratchpad_mode = attr->scratchpad_mode_; + + return success; +} + +status_t mkldnn_primitive_attr_set_scratchpad_mode( + primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) { + if (any_null(attr)) + return invalid_arguments; + + return attr->set_scratchpad_mode(scratchpad_mode); +} + +status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr, + dim_t *count, int *mask, const float **scales) { + if (any_null(attr, count, mask, scales)) + return invalid_arguments; + + *count = attr->output_scales_.count_; + *mask = attr->output_scales_.mask_; + *scales = attr->output_scales_.scales_; + + return success; +} + +status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr, + dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->output_scales_.set(count, mask, scales); +} + +status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr, + const post_ops_t **post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + *post_ops = &attr->post_ops_; + return success; +} + +status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr, + const post_ops_t *post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + return attr->set_post_ops(*post_ops); +} + +status_t mkldnn_post_ops_create(post_ops_t **post_ops) { + if (post_ops == nullptr) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops); +} + +status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) { + if (post_ops) + delete post_ops; + + return success; +} + +int mkldnn_post_ops_len(const post_ops_t *post_ops) { + if (post_ops) + return post_ops->len_; + + return 0; +} + +primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops, + int index) { + bool ok = post_ops && 0 <= index && index < post_ops->len_; + if (!ok) + return primitive_kind::undefined; + + return post_ops->entry_[index].kind; +} + +status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_sum(scale); +} + +namespace { +bool simple_get_params_check(const post_ops_t *post_ops, int index, + primitive_kind_t kind) { + bool ok = true + && post_ops != nullptr + && 0 <= index + && index < post_ops->len_ + && post_ops->entry_[index].kind == kind; + return ok; +} +} + +status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index, + float *scale) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::sum) + && !any_null(scale); + if (!ok) + return invalid_arguments; + + *scale = post_ops->entry_[index].sum.scale; + return success; +} + +status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale, + alg_kind_t kind, float alpha, float beta) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_eltwise(scale, kind, alpha, beta); +} + +status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops, + int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::eltwise) + && !any_null(scale, alpha, beta); + if (!ok) + return invalid_arguments; + + const auto &e = post_ops->entry_[index].eltwise; + *scale = e.scale; + *alg = e.alg; + *alpha = e.alpha; + *beta = e.beta; + + return success; +} + +status_t mkldnn_primitive_attr_set_rnn_data_qparams( + primitive_attr_t *attr, const float scale, const float shift) { + if (attr == nullptr) + return invalid_arguments; + + return attr->rnn_data_qparams_.set(scale, shift); +} + +status_t mkldnn_primitive_attr_set_rnn_weights_qparams( + primitive_attr_t *attr, dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->rnn_weights_qparams_.set(count, mask, scales); +} |