1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_SIMPLE_CONCAT_HPP |
18 | #define CPU_SIMPLE_CONCAT_HPP |
19 | |
20 | #include "common/memory_tracking.hpp" |
21 | #include "common/primitive.hpp" |
22 | |
23 | #include "cpu/platform.hpp" |
24 | |
25 | #include "cpu/cpu_concat_pd.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | |
// Reference "simple" concat primitive for CPU, specialized per data type.
// Applicable only when every source, each source's image view inside the
// destination, and the destination itself share the same blocked memory
// layout; the plan computed at init time reduces the concat to contiguous
// chunk copies (see execute(), declared below, defined elsewhere).
template <data_type_t data_type>
struct simple_concat_t : public primitive_t {
    // Primitive descriptor: validates applicability and precomputes the
    // dimension permutation (perm_/iperm_) and blocking info (blocks_)
    // that describe the contiguous region to be copied per input.
    struct pd_t : public cpu_concat_pd_t {
        using cpu_concat_pd_t::cpu_concat_pd_t;

        // Copy ctor must also replicate the derived-class plan arrays
        // (perm_/iperm_/blocks_), which the base copy cannot do.
        pd_t(const pd_t &rhs) : cpu_concat_pd_t(rhs) { copy_from(rhs); }

        DECLARE_CONCAT_PD_T("simple:any" , simple_concat_t);

        // Checks that this implementation applies and builds the copy plan.
        // Returns status::unimplemented when any precondition fails.
        status_t init(engine_t *engine) {
            const memory_desc_wrapper dst_d(dst_md());
            // basic applicability: supported data type, default attributes,
            // successful base-class init, and a bounded number of dims
            bool ok = platform::has_data_type_support(data_type)
                    && attr()->has_default_values()
                    && cpu_concat_pd_t::init() == status::success
                    && dst_d.ndims() <= 6;
            if (!ok) return status::unimplemented;

            // every input and its image view in dst must match the dst
            // layout: same data type, blocked format, and identical
            // blocking descriptors (strides are validated separately below)
            for (size_t i = 0; i < src_mds_.size(); ++i) {
                const memory_desc_wrapper i_d(&src_mds_[i]);
                const memory_desc_wrapper o_d(&src_image_mds_[i]);

                const bool ignore_strides = true;

                ok = ok
                        && utils::everyone_is(
                                data_type, i_d.data_type(), o_d.data_type())
                        && utils::everyone_is(format_kind::blocked,
                                i_d.format_kind(), o_d.format_kind())
                        && types::blocking_desc_is_equal(
                                *i_d.md_, *o_d.md_, ignore_strides)
                        && types::blocking_desc_is_equal(
                                *i_d.md_, *dst_d.md_, ignore_strides)
                        && !i_d.is_additional_buffer();
                if (!ok) return status::unimplemented;
            }

            dst_d.compute_blocks(blocks_);
            format_perm();

            // start dim is the first dimension after which the concatenation
            // would happen contiguously
            const int start_dim = perm_[concat_dim()];

            // check that contiguous part is indeed contiguous (i.e. dense)
            if (nelems_to_concat(dst_d)
                    != dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()]
                            * dst_d.blocking_desc().strides[concat_dim()])
                return status::unimplemented;

            // check that all inputs have the same strides for the
            // contiguous part [concat_dim .. ndims] for the *major* dims.
            // the block part is already checked above
            for (size_t i = 0; i < src_mds_.size(); ++i) {
                const memory_desc_wrapper i_d(&src_mds_[i]);
                for (int d = start_dim; d < dst_d.ndims(); ++d) {
                    if (dst_d.blocking_desc().strides[iperm_[d]]
                            != i_d.blocking_desc().strides[iperm_[d]])
                        return status::unimplemented;
                }
            }

            init_scratchpad();

            return status::success;
        }

        // perm_[d]: position of original dim d in decreasing-stride order;
        // iperm_ is its inverse (sorted position -> original dim);
        // blocks_: per-dim inner block sizes of the dst layout
        int perm_[DNNL_MAX_NDIMS] {};
        int iperm_[DNNL_MAX_NDIMS] {};
        dims_t blocks_ {};

        // Number of elements in the contiguous tail of `data_d` that starts
        // at the concat dim (in stride order): product of the outer
        // (padded-dims / block) extents from the concat dim inward, times
        // the product of all inner block sizes.
        dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
            const int ndims = data_d.ndims();

            dim_t nelems = 1;
            for (int i = perm_[concat_dim()]; i < ndims; i++)
                nelems *= data_d.padded_dims()[iperm_[i]] / blocks_[iperm_[i]];
            for (int i = 0; i < ndims; i++)
                nelems *= blocks_[i];

            return nelems;
        }

    private:
        // Builds perm_/iperm_ by sorting the dst dims by decreasing stride
        // (outermost dim first); ties are resolved by simultaneous_sort
        // using the outer block counts carried along as a secondary array.
        void format_perm() {
            const memory_desc_wrapper dst_d(dst_md());
            const int ndims = dst_d.ndims();

            dims_t blocks = {0};
            dst_d.compute_blocks(blocks);

            strides_t strides = {0};
            utils::array_copy(strides, dst_d.blocking_desc().strides, ndims);

            dims_t ou_blocks = {0};
            utils::array_copy(ou_blocks, dst_d.padded_dims(), ndims);

            // start from the identity permutation; ou_blocks holds the
            // number of outer blocks per dim (padded extent / block size)
            for (int d = 0; d < ndims; d++) {
                iperm_[d] = d;
                ou_blocks[d] /= blocks[d];
            }

            // comparator `b - a` yields descending order by stride;
            // iperm_ is permuted in lockstep with the sort keys
            utils::simultaneous_sort(strides, ou_blocks, iperm_, ndims,
                    [](stride_t a, stride_t b) { return b - a; });

            // invert iperm_ to obtain perm_ (original dim -> sorted position)
            for (int i = 0; i < ndims; i++)
                perm_[iperm_[i]] = i;
        }

        // Books per-input scratchpad arrays used at execution time:
        // source/destination pointers, per-input element counts, and
        // per-input stride arrays.
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.template book<data_t *>(key_concat_iptrs, n_inputs());
            scratchpad.template book<data_t *>(key_concat_optrs, n_inputs());
            scratchpad.template book<dim_t>(key_concat_nelems, n_inputs());
            scratchpad.template book<strides_t>(
                    key_concat_istrides, n_inputs());
        }

        // Copies the precomputed plan arrays from rhs (used by the copy
        // ctor); only the first ndims entries are meaningful.
        void copy_from(const pd_t &rhs) {
            int ndims = rhs.dst_md_.ndims;
            utils::array_copy(perm_, rhs.perm_, ndims);
            utils::array_copy(iperm_, rhs.iperm_, ndims);
            utils::array_copy(blocks_, rhs.blocks_, ndims);
        }
    };

    simple_concat_t(const pd_t *apd) : primitive_t(apd) {}

    // Performs the concat copies described by pd(); defined out of line.
    status_t execute(const exec_ctx_t &ctx) const override;

    // C++ type corresponding to the data_type template parameter
    typedef typename prec_traits<data_type>::type data_t;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};
166 | |
167 | } // namespace cpu |
168 | } // namespace impl |
169 | } // namespace dnnl |
170 | |
171 | #endif |
172 | |