1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_CONCAT_PD_HPP
18#define COMMON_CONCAT_PD_HPP
19
20#include <assert.h>
21
22#include "c_types_map.hpp"
23#include "primitive_desc.hpp"
24#include "type_helpers.hpp"
25
26#include "utils.hpp"
27
28namespace dnnl {
29namespace impl {
30
31struct concat_pd_t : public primitive_desc_t {
32 const concat_desc_t *desc() const { return &desc_; }
33 const op_desc_t *op_desc() const override {
34 return reinterpret_cast<const op_desc_t *>(this->desc());
35 }
36
37 arg_usage_t arg_usage(int arg) const override {
38 if (arg >= DNNL_ARG_MULTIPLE_SRC
39 && arg < DNNL_ARG_MULTIPLE_SRC + n_inputs())
40 return arg_usage_t::input;
41
42 if (arg == DNNL_ARG_DST) return arg_usage_t::output;
43
44 return primitive_desc_t::arg_usage(arg);
45 }
46
47 const memory_desc_t *arg_md(int arg) const override {
48 int src_index = arg - DNNL_ARG_MULTIPLE_SRC;
49 if (src_index >= 0 && src_index < n_inputs()) return src_md(src_index);
50 if (arg == DNNL_ARG_DST) return dst_md(0);
51 return primitive_desc_t::arg_md(arg);
52 }
53
54 const memory_desc_t *src_md(int index = 0) const override {
55 return index < n_inputs() ? &src_mds_[index] : &glob_zero_md;
56 }
57 const memory_desc_t *dst_md(int index = 0) const override {
58 return index == 0 ? &dst_md_ : &glob_zero_md;
59 }
60
61 int n_inputs() const override { return n_; }
62 int n_outputs() const override { return 1; }
63
64 int concat_dim() const { return concat_dim_; }
65
66 const memory_desc_t *src_image_md(int index = 0) const {
67 return index < n_inputs() ? &src_image_mds_[index] : &glob_zero_md;
68 }
69
70protected:
71 int n_, concat_dim_;
72 memory_desc_t dst_md_;
73 memory_desc_t original_dst_;
74 std::vector<memory_desc_t> src_mds_;
75
76 /* contains images of srcs in the dst memory (if possible)
77 * Lives here to simplify some implementations. An implementation might
78 * use this auxiliary array iff init() returned success */
79 std::vector<memory_desc_t> src_image_mds_;
80
81protected:
82 concat_desc_t desc_;
83
84 concat_pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md,
85 int n, int concat_dim, const memory_desc_t *const *src_mds)
86 : primitive_desc_t(attr, primitive_kind::concat)
87 , n_(n)
88 , concat_dim_(concat_dim)
89 , dst_md_(*dst_md)
90 , original_dst_(*dst_md) {
91 src_mds_.reserve(n_);
92 for (int i = 0; i < n_; ++i)
93 src_mds_.push_back(*src_mds[i]);
94
95 init_desc();
96 }
97
98 concat_pd_t(const concat_pd_t &other) : primitive_desc_t(other) {
99 n_ = other.n_;
100 concat_dim_ = other.concat_dim_;
101 dst_md_ = other.dst_md_;
102 original_dst_ = other.original_dst_;
103 src_mds_ = other.src_mds_;
104 src_image_mds_ = other.src_image_mds_;
105
106 init_desc();
107 }
108
109 /* inits src_image_mds_ and dst_md_ in simple cases. It is possible to
110 * override dst_md_ by using force_dst_md.
111 * Rationale: if user forces particular dst_md, that cannot be used to
112 * create src_img_mds, the implementation might need to use
113 * intermediate (force_dst_md) memory with some plain format.
114 *
115 * @warning The call may fail. */
116 status_t init(const memory_desc_t *force_dst_md = nullptr) {
117 bool ok = true;
118 if (force_dst_md == nullptr)
119 ok = ok && set_default_params() == status::success;
120 if (!ok) return status::unimplemented;
121
122 /* work with force_dst_md */
123 if (force_dst_md == nullptr) force_dst_md = &dst_md_;
124
125 for (int i = 0; i < n_; ++i) {
126 const memory_desc_wrapper i_d(&src_mds_[i]);
127 if (!i_d.is_blocking_desc() || i_d.is_additional_buffer())
128 return status::unimplemented;
129 }
130
131 const int ndims = force_dst_md->ndims;
132 int current_concat_dim_offset = 0;
133 for (int i = 0; i < n_; ++i) {
134 const int dim = src_mds_[i].dims[concat_dim_];
135 dims_t dims, offsets = {};
136 utils::array_copy(dims, force_dst_md->dims, ndims);
137 dims[concat_dim_] = dim;
138 offsets[concat_dim_] = current_concat_dim_offset;
139
140 memory_desc_t src_img_d;
141 status_t status = memory_desc_init_submemory(
142 src_img_d, *force_dst_md, dims, offsets);
143 if (status != status::success) {
144 src_image_mds_.clear();
145 return status;
146 }
147 src_image_mds_.push_back(src_img_d);
148 current_concat_dim_offset += dim;
149 }
150
151 return status::success;
152 }
153
154 status_t set_default_params() {
155 if (dst_md_.format_kind != format_kind::any) return status::success;
156
157 const int ndims = dst_md_.ndims;
158
159 /* The stupidest ever heuristics (but not the same as we had before):
160 * - Pick the first non-plain format;
161 * - If all formats are plain or it is not possible to create a
162 * blocked format for the output, pick the format of the plain input
163 * - If this fails as well, use plain layout (abcd...)
164 */
165 status_t status = status::unimplemented;
166 for (int i = 0; i < n_; ++i) {
167 const memory_desc_wrapper src_d(src_mds_[i]);
168 if (src_d.is_blocking_desc() && !src_d.is_plain()) {
169 status = memory_desc_init_by_blocking_desc(
170 dst_md_, src_d.blocking_desc());
171 if (status == status::success) break;
172 }
173 }
174
175 if (status == status::success) {
176 /* check if we can create a sub-memory for the dst */
177 bool desired_format_ok = true;
178 dims_t dims {}, offsets {};
179 utils::array_copy(dims, dst_md_.dims, ndims);
180
181 for (int i = 0; i < n_; ++i) {
182 const auto dim = src_mds_[i].dims[concat_dim_];
183 dims[concat_dim_] = dim;
184
185 memory_desc_t src_img_d;
186 status_t status = memory_desc_init_submemory(
187 src_img_d, dst_md_, dims, offsets);
188 if (status != status::success) {
189 desired_format_ok = false;
190 break;
191 }
192 offsets[concat_dim_] += dim;
193 }
194
195 if (!desired_format_ok) status = status::unimplemented;
196 }
197
198 /* if no success so far, try using the format of the first plain input */
199 if (status != status::success) {
200 for (int i = 0; i < n_; ++i) {
201 const memory_desc_wrapper src_d(src_mds_[i]);
202 // Dim of `1` may tweak a destination format leading to
203 // sub-optimal performance. Limit it to an axis case to allow
204 // case like a:a->ab or a:ab->ab to work properly.
205 // TODO: update the whole logic to getting string tags of
206 // sources but discarding dims of one. If ndims of any source
207 // coincides with dst ndims, use that tag (if they are same).
208 // If dst has +1 ndim (due to concat dim), use slices as dense
209 // layers inside a dst, which means axis should be the least
210 // dense dimension.
211 const bool axis_dim_has_one = src_d.dims()[concat_dim()] == 1;
212 if (!axis_dim_has_one && src_d.is_blocking_desc()
213 && src_d.is_plain() && src_d.nelems() > 0) {
214 status = memory_desc_init_by_blocking_desc(dst_md_,
215 memory_desc_wrapper(src_mds_[i]).blocking_desc());
216 if (status == status::success) return status;
217 }
218 }
219 }
220
221 /* the last line of defense: use plain abcd... format */
222 if (status != status::success)
223 status = memory_desc_init_by_strides(dst_md_, nullptr);
224
225 return status;
226 }
227
228private:
229 void init_desc() {
230 desc_ = concat_desc_t();
231 desc_.primitive_kind = primitive_kind::concat;
232 desc_.dst_md = &original_dst_;
233 desc_.n = n_;
234 desc_.concat_dimension = concat_dim_;
235 for (const auto &md : src_mds_)
236 desc_.src_mds.push_back(&md);
237 }
238};
239
/* Boilerplate to be expanded inside the pd_t of a concat implementation.
 * It provides:
 *  - create(): allocates a pd_t, runs its init(engine) and, on success,
 *    calls init_scratchpad_md() and hands the object over through
 *    safe_ptr_assign(). A failed allocation yields out_of_memory; a failed
 *    init() deletes the object and yields unimplemented;
 *  - create_primitive(): forwards to primitive_t::create_primitive_common
 *    with the primitive type passed in __VA_ARGS__;
 *  - clone(): copy-constructs a pd_t, returning nullptr when the copy reports
 *    that it is not initialized;
 *  - name(): returns impl_name. */
#define DECLARE_CONCAT_PD_t(impl_name, ...) \
    static status_t create(concat_pd_t **concat_pd, engine_t *engine, \
            const primitive_attr_t *attr, const memory_desc_t *dst_md, int n, \
            int concat_dim, const memory_desc_t *const *src_mds) { \
        using namespace status; \
        pd_t *candidate = new pd_t(attr, dst_md, n, concat_dim, src_mds); \
        if (!candidate) return out_of_memory; \
        const bool init_ok = candidate->init(engine) == success; \
        if (!init_ok) { \
            delete candidate; \
            return unimplemented; \
        } \
        candidate->init_scratchpad_md(); \
        return safe_ptr_assign(*concat_pd, candidate); \
    } \
    status_t create_primitive( \
            std::pair<std::shared_ptr<primitive_t>, bool> &primitive, \
            engine_t *engine, const cache_blob_t &cache_blob) const override { \
        return primitive_t::create_primitive_common<__VA_ARGS__, pd_t>( \
                primitive, this, engine, false, cache_blob); \
    } \
    pd_t *clone() const override { \
        auto copy = utils::make_unique<pd_t>(*this); \
        return copy->is_initialized() ? copy.release() : nullptr; \
    } \
    const char *name() const override { return impl_name; }
266
/* Upper-case spelling of DECLARE_CONCAT_PD_t; every argument is forwarded
 * unchanged. */
#define DECLARE_CONCAT_PD_T(impl_name, ...) \
    DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__)
269
270} // namespace impl
271} // namespace dnnl
272
273#endif
274