1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_SIMPLE_CONCAT_HPP
18#define CPU_SIMPLE_CONCAT_HPP
19
20#include "common/memory_tracking.hpp"
21#include "common/primitive.hpp"
22
23#include "cpu/platform.hpp"
24
25#include "cpu/cpu_concat_pd.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30
31template <data_type_t data_type>
32struct simple_concat_t : public primitive_t {
33 struct pd_t : public cpu_concat_pd_t {
34 using cpu_concat_pd_t::cpu_concat_pd_t;
35
36 pd_t(const pd_t &rhs) : cpu_concat_pd_t(rhs) { copy_from(rhs); }
37
38 DECLARE_CONCAT_PD_T("simple:any", simple_concat_t);
39
40 status_t init(engine_t *engine) {
41 const memory_desc_wrapper dst_d(dst_md());
42 bool ok = platform::has_data_type_support(data_type)
43 && attr()->has_default_values()
44 && cpu_concat_pd_t::init() == status::success
45 && dst_d.ndims() <= 6;
46 if (!ok) return status::unimplemented;
47
48 for (size_t i = 0; i < src_mds_.size(); ++i) {
49 const memory_desc_wrapper i_d(&src_mds_[i]);
50 const memory_desc_wrapper o_d(&src_image_mds_[i]);
51
52 const bool ignore_strides = true;
53
54 ok = ok
55 && utils::everyone_is(
56 data_type, i_d.data_type(), o_d.data_type())
57 && utils::everyone_is(format_kind::blocked,
58 i_d.format_kind(), o_d.format_kind())
59 && types::blocking_desc_is_equal(
60 *i_d.md_, *o_d.md_, ignore_strides)
61 && types::blocking_desc_is_equal(
62 *i_d.md_, *dst_d.md_, ignore_strides)
63 && !i_d.is_additional_buffer();
64 if (!ok) return status::unimplemented;
65 }
66
67 dst_d.compute_blocks(blocks_);
68 format_perm();
69
70 // start dim is the first dimension after which the concatenation
71 // would happen contiguously
72 const int start_dim = perm_[concat_dim()];
73
74 // check that contiguous part is indeed contiguous (i.e. dense)
75 if (nelems_to_concat(dst_d)
76 != dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()]
77 * dst_d.blocking_desc().strides[concat_dim()])
78 return status::unimplemented;
79
80 // check that all inputs have the same strides for the
81 // contiguous part [concat_dim .. ndims] for the *major* dims.
82 // the block part is already checked above
83 for (size_t i = 0; i < src_mds_.size(); ++i) {
84 const memory_desc_wrapper i_d(&src_mds_[i]);
85 for (int d = start_dim; d < dst_d.ndims(); ++d) {
86 if (dst_d.blocking_desc().strides[iperm_[d]]
87 != i_d.blocking_desc().strides[iperm_[d]])
88 return status::unimplemented;
89 }
90 }
91
92 init_scratchpad();
93
94 return status::success;
95 }
96
97 int perm_[DNNL_MAX_NDIMS] {};
98 int iperm_[DNNL_MAX_NDIMS] {};
99 dims_t blocks_ {};
100
101 dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
102 const int ndims = data_d.ndims();
103
104 dim_t nelems = 1;
105 for (int i = perm_[concat_dim()]; i < ndims; i++)
106 nelems *= data_d.padded_dims()[iperm_[i]] / blocks_[iperm_[i]];
107 for (int i = 0; i < ndims; i++)
108 nelems *= blocks_[i];
109
110 return nelems;
111 }
112
113 private:
114 void format_perm() {
115 const memory_desc_wrapper dst_d(dst_md());
116 const int ndims = dst_d.ndims();
117
118 dims_t blocks = {0};
119 dst_d.compute_blocks(blocks);
120
121 strides_t strides = {0};
122 utils::array_copy(strides, dst_d.blocking_desc().strides, ndims);
123
124 dims_t ou_blocks = {0};
125 utils::array_copy(ou_blocks, dst_d.padded_dims(), ndims);
126
127 for (int d = 0; d < ndims; d++) {
128 iperm_[d] = d;
129 ou_blocks[d] /= blocks[d];
130 }
131
132 utils::simultaneous_sort(strides, ou_blocks, iperm_, ndims,
133 [](stride_t a, stride_t b) { return b - a; });
134
135 for (int i = 0; i < ndims; i++)
136 perm_[iperm_[i]] = i;
137 }
138
139 void init_scratchpad() {
140 using namespace memory_tracking::names;
141 auto scratchpad = scratchpad_registry().registrar();
142 scratchpad.template book<data_t *>(key_concat_iptrs, n_inputs());
143 scratchpad.template book<data_t *>(key_concat_optrs, n_inputs());
144 scratchpad.template book<dim_t>(key_concat_nelems, n_inputs());
145 scratchpad.template book<strides_t>(
146 key_concat_istrides, n_inputs());
147 }
148
149 void copy_from(const pd_t &rhs) {
150 int ndims = rhs.dst_md_.ndims;
151 utils::array_copy(perm_, rhs.perm_, ndims);
152 utils::array_copy(iperm_, rhs.iperm_, ndims);
153 utils::array_copy(blocks_, rhs.blocks_, ndims);
154 }
155 };
156
157 simple_concat_t(const pd_t *apd) : primitive_t(apd) {}
158
159 status_t execute(const exec_ctx_t &ctx) const override;
160
161 typedef typename prec_traits<data_type>::type data_t;
162
163private:
164 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
165};
166
167} // namespace cpu
168} // namespace impl
169} // namespace dnnl
170
171#endif
172