summaryrefslogtreecommitdiff
path: root/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
blob: 51772754523a9a2b70f5bcfeaa6b8a34b95c5ede (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef SIMPLE_CONCAT_HPP
#define SIMPLE_CONCAT_HPP

#include "memory_tracking.hpp"

#include "cpu_concat_pd.hpp"
#include "cpu_primitive.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

// Concat primitive for tensors in blocked format, parameterized by the element
// data type. Only the descriptor (pd_t) is defined inline in this header;
// execute() is declared here and implemented elsewhere.
template <data_type_t data_type>
struct simple_concat_t: public cpu_primitive_t {
    // Descriptor for simple_concat_t. init() decides whether this
    // implementation accepts a given configuration and, as a by-product,
    // caches a dimension permutation (perm_/iperm_) and per-dimension block
    // sizes (blocks_) derived from the destination memory descriptor.
    struct pd_t: public cpu_concat_pd_t {
        using cpu_concat_pd_t::cpu_concat_pd_t;

        // Copy constructor. After the base-class copy, the first `ndims`
        // entries of each cached array are copied explicitly, where `ndims`
        // is read from the copied-from object's destination descriptor.
        pd_t(const pd_t &other): cpu_concat_pd_t(other) {
            const int rank = other.dst_md_.ndims;
            utils::array_copy(perm_, other.perm_, rank);
            utils::array_copy(iperm_, other.iperm_, rank);
            utils::array_copy(blocks_, other.blocks_, rank);
        }

        DECLARE_CONCAT_PD_T("simple:any", simple_concat_t);

        // Validates the concat configuration and fills blocks_, perm_ and
        // iperm_. Returns status::success when every check passes, and
        // status::unimplemented as soon as any single check fails.
        status_t init() {
            const memory_desc_wrapper dst_d(dst_md());

            if (cpu_concat_pd_t::init() != status::success)
                return status::unimplemented;
            // NOTE(review): the cap of 6 dimensions presumably mirrors a
            // limit of execute(), which is not visible here -- confirm.
            if (dst_d.ndims() > 6) return status::unimplemented;

            // Forwarded as the last argument of blocking_desc_is_equal().
            // NOTE(review): per its name this is meant to leave strides out
            // of the comparison; the exact meaning is defined by
            // types::blocking_desc_is_equal -- confirm there.
            const int ignore_strides = 0;

            const size_t n_srcs = src_mds_.size();

            // Per-source requirements, tested in this order:
            //  1. source and source-image use this primitive's data type;
            //  2. both are in blocked format;
            //  3. source blocking matches its image and the destination;
            //  4. the source is not flagged as an additional buffer.
            for (size_t s = 0; s < n_srcs; ++s) {
                const memory_desc_wrapper src_d(&src_mds_[s]);
                const memory_desc_wrapper img_d(&src_image_mds_[s]);

                if (!utils::everyone_is(data_type, src_d.data_type(),
                            img_d.data_type()))
                    return status::unimplemented;
                if (!utils::everyone_is(format_kind::blocked,
                            src_d.format_kind(), img_d.format_kind()))
                    return status::unimplemented;
                if (!types::blocking_desc_is_equal(src_d.blocking_desc(),
                            img_d.blocking_desc(), ignore_strides))
                    return status::unimplemented;
                if (!types::blocking_desc_is_equal(src_d.blocking_desc(),
                            dst_d.blocking_desc(), ignore_strides))
                    return status::unimplemented;
                if (src_d.is_additional_buffer())
                    return status::unimplemented;
            }

            dst_d.compute_blocks(blocks_);
            format_perm();

            const auto cdim = concat_dim();
            const auto &dst_blk = dst_d.blocking_desc();

            // Position of the concat dimension in the stride-sorted order;
            // every dimension from this position onwards is expected to form
            // the contiguous chunk that gets concatenated.
            const int start_dim = perm_[cdim];

            // The chunk must be dense: its element count has to equal the
            // extent it spans along the concat dimension, i.e.
            // (padded size / block size) * stride of that dimension.
            const auto outer_extent = dst_d.padded_dims()[cdim] / blocks_[cdim];
            const auto dense_span = outer_extent * dst_blk.strides[cdim];
            if (nelems_to_concat(dst_d) != dense_span)
                return status::unimplemented;

            // Within the contiguous chunk every source must use the same
            // strides as the destination. Only the strides stored in the
            // blocking descriptor are compared here; the block part was
            // already compared by the per-source checks above.
            const int ndims = dst_d.ndims();
            for (size_t s = 0; s < n_srcs; ++s) {
                const memory_desc_wrapper src_d(&src_mds_[s]);
                const auto &src_blk = src_d.blocking_desc();
                for (int pos = start_dim; pos < ndims; ++pos) {
                    const int dim = iperm_[pos];
                    if (dst_blk.strides[dim] != src_blk.strides[dim])
                        return status::unimplemented;
                }
            }

            init_scratchpad();

            return status::success;
        }

        // perm_[d]: position of logical dimension d in the stride-sorted
        // order built by format_perm(). iperm_ is the inverse mapping
        // (sorted position -> logical dimension).
        int perm_[MKLDNN_MAX_NDIMS];
        int iperm_[MKLDNN_MAX_NDIMS];
        // Filled by memory_desc_wrapper::compute_blocks() on the destination
        // descriptor; used as a per-dimension divisor/multiplier.
        dims_t blocks_;

        // Element count of the chunk that is concatenated for `data_d`:
        // the product of dims()[d] / blocks_[d] over all dimensions located
        // at sorted position perm_[concat_dim()] or later, multiplied by
        // every entry of blocks_[0 .. ndims).
        // NOTE(review): this uses dims(), while init() compares the result
        // with a value derived from padded_dims() -- confirm that the
        // difference is intended.
        dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
            const int ndims = data_d.ndims();
            const int first_pos = perm_[concat_dim()];

            dim_t count = 1;
            for (int pos = first_pos; pos < ndims; ++pos) {
                const int dim = iperm_[pos];
                count *= data_d.dims()[dim] / blocks_[dim];
            }
            for (int dim = 0; dim < ndims; ++dim)
                count *= blocks_[dim];

            return count;
        }

    private:
        // Builds iperm_ by sorting the dimension indices together with a
        // copy of the destination strides, then derives perm_ as its inverse.
        // NOTE(review): the comparator returns (second - first), so the
        // result is presumably ordered by decreasing stride -- confirm
        // against utils::simultaneous_sort.
        void format_perm() {
            const memory_desc_wrapper dst_d(dst_md());
            const int ndims = dst_d.ndims();

            // start from the identity mapping
            for (int d = 0; d < ndims; ++d) iperm_[d] = d;

            // sort a scratch copy so that the descriptor itself is untouched
            strides_t sort_keys;
            utils::array_copy(sort_keys, dst_d.blocking_desc().strides, ndims);

            auto stride_cmp
                    = [](stride_t lhs, stride_t rhs) { return rhs - lhs; };
            utils::simultaneous_sort(sort_keys, iperm_, ndims, stride_cmp);

            // invert: sorted position -> dim becomes dim -> sorted position
            for (int pos = 0; pos < ndims; ++pos) perm_[iperm_[pos]] = pos;
        }

        // Books four scratchpad areas, each sized for one entry per input:
        // input pointers, output pointers, element counts and strides.
        void init_scratchpad() {
            using namespace memory_tracking::names;

            const auto n_srcs = n_inputs();
            auto reg = scratchpad_registry().registrar();

            reg.book(key_concat_iptrs, sizeof(data_t *) * n_srcs);
            reg.book(key_concat_optrs, sizeof(data_t *) * n_srcs);
            reg.book(key_concat_nelems, sizeof(dim_t) * n_srcs);
            reg.book(key_concat_istrides, sizeof(strides_t) * n_srcs);
        }
    };

    // Element type that corresponds to the `data_type` template argument.
    using data_t = typename prec_traits<data_type>::type;

    // Constructs the primitive from its descriptor.
    simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {}

    // Runs the concatenation; the definition is not part of this header.
    virtual status_t execute(const exec_ctx_t &ctx) const override;

private:
    // Typed view of the descriptor held by the primitive_t base.
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};

}
}
}

#endif