summaryrefslogtreecommitdiff
path: root/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp')
-rw-r--r--thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp266
1 files changed, 266 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp
new file mode 100644
index 0000000000..9e77b890d5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp
@@ -0,0 +1,266 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GEMM_X8S8S32X_CONVOLUTION_HPP
+#define GEMM_X8S8S32X_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_primitive_conf.hpp"
+#include "jit_generator.hpp"
+#include "gemm_convolution_utils.hpp"
+
+#include "gemm/gemm.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t src_type, data_type_t dst_type>
+struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR,
+ _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>);
+
+ status_t init() {
+ using namespace data_type;
+
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, s8, data_type::undef, dst_type,
+ s32)
+ && IMPLICATION(with_bias(), utils::one_of(
+ desc()->bias_desc.data_type, f32, s32, s8, u8))
+ && !has_zero_dim_memory()
+ && set_default_formats_common(
+ dat_tag(), format_tag::any, dat_tag())
+ && post_ops_ok()
+ && memory_desc_matches_tag(*src_md(), dat_tag())
+ && memory_desc_matches_tag(*dst_md(), dat_tag())
+ && set_or_check_wei_format();
+ if (!ok) return status::unimplemented;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+ *desc(), src_md(), weights_md(0), dst_md(),
+ mkldnn_get_max_threads());
+ }
+
+ jit_gemm_conv_conf_t jcp_;
+
+ protected:
+ format_tag_t dat_tag() const { return format_tag::nhwc; }
+
+ bool set_or_check_wei_format() {
+ using namespace format_tag;
+
+ const bool is_src_s8 = src_md_.data_type == data_type::s8;
+
+ memory_desc_t want_wei_md = weights_md_;
+ memory_desc_init_by_tag(want_wei_md, with_groups() ? hwigo : hwio);
+
+ if (is_src_s8) {
+ want_wei_md.extra.flags = 0
+ | memory_extra_flags::compensation_conv_s8s8
+ | memory_extra_flags::scale_adjust;
+ want_wei_md.extra.compensation_mask = (1 << 0)
+ + (with_groups() ? (1 << 1) : 0);
+ want_wei_md.extra.scale_adjust =
+ mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
+ }
+
+ if (weights_md_.format_kind == format_kind::any) {
+ weights_md_ = want_wei_md;
+ return true;
+ }
+
+ return weights_md_ == want_wei_md;
+ }
+
+ bool post_ops_ok() const {
+ using namespace mkldnn::impl::primitive_kind;
+ auto const &po = attr()->post_ops_;
+ auto is_relu = [&](int idx) {
+ return po.entry_[idx].is_relu(true, false); };
+
+ switch (po.len_) {
+ case 0: return true;
+ case 1: return is_relu(0) || po.contain(sum, 0);
+ case 2: return po.contain(sum, 0) && is_relu(1);
+ default: return false;
+ }
+ return false;
+ }
+ };
+
+ _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true), pp_ker_(nullptr)
+ { pp_ker_ = new pp_ker_t(pd()); }
+ ~_gemm_x8s8s32x_convolution_fwd_t() { delete pp_ker_; }
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+ typedef typename prec_traits<data_type::s32>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ // XXX: this is throwaway code that will become unnecessary when we have a
+ // sufficiently advanced igemm jit generator that supports quantization,
+ // relu, and whatnot
+ class pp_ker_t : jit_generator {
+ public:
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(
+ _gemm_x8s8s32x_convolution_fwd_t::pp_kernel);
+ pp_ker_t(const pd_t *pd);
+
+ void operator()(dst_data_t *dst, const acc_data_t *acc,
+ const char *bias, const float *scales,
+ float nslope, float sum_scale, float signed_scale,
+ int g, size_t start, size_t end);
+
+ size_t dst_os_stride_;
+
+ private:
+ void generate();
+
+ struct ker_args {
+ dst_data_t *dst;
+ const acc_data_t *acc;
+ const char *bias;
+ const float *scales;
+ float nslope;
+ float sum_scale;
+ float signed_scale;
+ size_t len;
+ size_t oc_offset;
+ };
+ void(*ker_)(const ker_args *args);
+
+ const jit_gemm_conv_conf_t &jcp_;
+ size_t OC_;
+ size_t OS_;
+ data_type_t bias_data_type_;
+ size_t bias_data_type_size_;
+ size_t scale_idx_mult_;
+ bool do_bias_;
+ bool do_relu_;
+ bool do_sum_;
+ bool do_signed_scaling_;
+ size_t vlen_;
+ };
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ void execute_forward(const exec_ctx_t &ctx) const;
+ void execute_forward_thr(const int ithr, const int nthr,
+ const src_data_t *src_base, const wei_data_t *wei_base,
+ const char *bia_base, dst_data_t *dst_base,
+ const memory_tracking::grantor_t &scratchpad) const;
+
+ int nthr_;
+ pp_ker_t *pp_ker_;
+
+};
+
+template <data_type_t dst_type>
+struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_data_pd_t{
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc, const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR,
+ _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>);
+
+ status_t init() {
+ using namespace data_type;
+
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(dst_type, s8, data_type::undef, u8, s32)
+ && IMPLICATION(with_bias(), utils::one_of(
+ desc()->bias_desc.data_type, f32, s32, s8, u8))
+ && !has_zero_dim_memory()
+ && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
+ && attr()->post_ops_.has_default_values()
+ && memory_desc_matches_tag(*diff_src_md(), dat_tag())
+ && memory_desc_matches_tag(*diff_dst_md(), dat_tag())
+ && memory_desc_matches_tag(*weights_md(), wei_tag());
+ if (!ok) return status::unimplemented;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+ *desc(), diff_src_md(), weights_md(), diff_dst_md(),
+ mkldnn_get_max_threads());
+ }
+
+ virtual bool support_bias() const override { return true; }
+
+ jit_gemm_conv_conf_t jcp_;
+
+ protected:
+ format_tag_t dat_tag() const { return format_tag::nhwc; }
+
+ format_tag_t wei_tag() const {
+ return with_groups() ? format_tag::hwigo : format_tag::hwio;
+ }
+ };
+
+ _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true) {}
+
+ typedef typename prec_traits<data_type::u8>::type diff_dst_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type diff_src_data_t;
+ typedef typename prec_traits<data_type::s32>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ void execute_backward_data_thr(const int ithr, const int nthr,
+ const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
+ const char *bia_base, diff_src_data_t *diff_src_base,
+ const memory_tracking::grantor_t &scratchpad) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif