summaryrefslogtreecommitdiff
path: root/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp')
-rw-r--r--thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp155
1 files changed, 155 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
new file mode 100644
index 0000000000..5177275452
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
@@ -0,0 +1,155 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SIMPLE_CONCAT_HPP
+#define SIMPLE_CONCAT_HPP
+
+#include "memory_tracking.hpp"
+
+#include "cpu_concat_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t data_type>
+struct simple_concat_t: public cpu_primitive_t {
+ struct pd_t: public cpu_concat_pd_t {
+ using cpu_concat_pd_t::cpu_concat_pd_t;
+
+ pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) {
+ int ndims = rhs.dst_md_.ndims;
+ utils::array_copy(perm_, rhs.perm_, ndims);
+ utils::array_copy(iperm_, rhs.iperm_, ndims);
+ utils::array_copy(blocks_, rhs.blocks_, ndims);
+ }
+
+ DECLARE_CONCAT_PD_T("simple:any", simple_concat_t);
+
+ status_t init() {
+ const memory_desc_wrapper dst_d(dst_md());
+ bool ok = true
+ && cpu_concat_pd_t::init() == status::success
+ && dst_d.ndims() <= 6;
+ if (!ok) return status::unimplemented;
+
+ for (size_t i = 0; i < src_mds_.size(); ++i) {
+ const memory_desc_wrapper i_d(&src_mds_[i]);
+ const memory_desc_wrapper o_d(&src_image_mds_[i]);
+
+ const int ignore_strides = 0;
+
+ ok = ok
+ && utils::everyone_is(data_type, i_d.data_type(),
+ o_d.data_type())
+ && utils::everyone_is(format_kind::blocked,
+ i_d.format_kind(), o_d.format_kind())
+ && types::blocking_desc_is_equal(i_d.blocking_desc(),
+ o_d.blocking_desc(), ignore_strides)
+ && types::blocking_desc_is_equal(i_d.blocking_desc(),
+ dst_d.blocking_desc(), ignore_strides)
+ && !i_d.is_additional_buffer();
+ if (!ok) return status::unimplemented;
+ }
+
+ dst_d.compute_blocks(blocks_);
+ format_perm();
+
+ // start dim is the first dimension after which the concatenation
+ // would happen contiguously
+ const int start_dim = perm_[concat_dim()];
+
+ // check that contiguous part is indeed contiguous (i.e. dense)
+ if (nelems_to_concat(dst_d) !=
+ dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()]
+ * dst_d.blocking_desc().strides[concat_dim()])
+ return status::unimplemented;
+
+ // check that all inputs have the same strides for the
+ // contiguous part [concat_dim .. ndims] for the *major* dims.
+ // the block part is already checked above
+ for (size_t i = 0; i < src_mds_.size(); ++i) {
+ const memory_desc_wrapper i_d(&src_mds_[i]);
+ for (int d = start_dim; d < dst_d.ndims(); ++d) {
+ if (dst_d.blocking_desc().strides[iperm_[d]]
+ != i_d.blocking_desc().strides[iperm_[d]])
+ return status::unimplemented;
+ }
+ }
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ int perm_[MKLDNN_MAX_NDIMS];
+ int iperm_[MKLDNN_MAX_NDIMS];
+ dims_t blocks_;
+
+ dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
+ const int ndims = data_d.ndims();
+
+ dim_t nelems = 1;
+ for (int i = perm_[concat_dim()]; i < ndims; i++)
+ nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]];
+ for (int i = 0; i < ndims; i++)
+ nelems *= blocks_[i];
+
+ return nelems;
+ }
+
+ private:
+ void format_perm() {
+ const memory_desc_wrapper dst_d(dst_md());
+ const int ndims = dst_d.ndims();
+
+ strides_t strides;
+ utils::array_copy(strides, dst_d.blocking_desc().strides, ndims);
+ for (int i = 0; i < ndims; i++) iperm_[i] = i;
+
+ utils::simultaneous_sort(strides, iperm_, ndims,
+ [](stride_t a, stride_t b) { return b - a; });
+
+ for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i;
+ }
+
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs());
+ scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs());
+ scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs());
+ scratchpad.book(key_concat_istrides,
+ sizeof(strides_t) * n_inputs());
+ }
+ };
+
+ simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override;
+
+ typedef typename prec_traits<data_type>::type data_t;
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif