1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <algorithm> |
17 | |
18 | #include "gpu/compute/dispatch.hpp" |
19 | #include "gpu/ocl/simple_concat.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace ocl { |
25 | |
26 | /* Returns dimension indices in (our best guess at) nesting order */ |
27 | std::vector<int> get_ordered_dim_idxs(const memory_desc_wrapper &mdw) { |
28 | const auto ndims = mdw.ndims(); |
29 | std::vector<int> idxs(ndims); |
30 | for (int i = 0; i < ndims; ++i) |
31 | idxs[i] = i; |
32 | |
33 | const auto &strides = mdw.blocking_desc().strides; |
34 | const auto &sizes = mdw.dims(); |
35 | const auto cmp = [&](int a, int b) { |
36 | // Dimensions of size 1 have the same stride as the next outer |
37 | // dimension, so only sorting by strides will not necessarily get the |
38 | // correct order. In the case of ties, the dim with the larger size gets |
39 | // sorted second. Permutations of dims of size 1 with the same stride |
        // should not affect the correctness of concat.
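        // E.g. for dims {2, 1, 4} with strides {4, 4, 1}, dim 1 ties with
        // dim 0 on stride but sorts first (innermost) due to its smaller size.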
41 | return strides[a] < strides[b] |
42 | || (strides[a] == strides[b] && sizes[a] < sizes[b]); |
43 | }; |
44 | std::sort(idxs.begin(), idxs.end(), cmp); |
45 | return idxs; |
46 | } |
47 | |
/** Returns true if the memory descriptor's axes are nested in the given order. */
49 | bool is_same_axis_order( |
50 | const std::vector<int> &idxs, const memory_desc_wrapper &mdw) { |
51 | const auto ndims = mdw.ndims(); |
52 | |
53 | // Compute the total size of blocks for each dim to help predict strides |
54 | std::vector<dim_t> blocks(ndims, 1); |
55 | const auto &blkg = mdw.blocking_desc(); |
56 | for (int i = 0; i < blkg.inner_nblks; ++i) |
57 | blocks[blkg.inner_idxs[i]] *= blkg.inner_blks[i]; |
58 | |
59 | // Check that the order specified by idxs matches the src tensor |
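    // Walking the axes in the claimed inner-to-outer order, each stride must
    // be at least the span (min_stride) of everything nested inside it.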
60 | dim_t min_stride = 1; |
61 | const auto &padded_dims = mdw.padded_dims(); |
62 | for (auto idx : idxs) { |
63 | auto stride = blkg.strides[idx]; |
64 | if (stride < min_stride) return false; |
65 | auto step = utils::div_up(padded_dims[idx], blocks[idx]); |
66 | min_stride = stride * step; |
67 | } |
68 | return true; |
69 | } |
70 | |
71 | static status_t init_conf_common( |
72 | engine_t *engine, concat_conf_t &conf, const concat_pd_t *pd) { |
73 | using namespace utils; |
74 | |
75 | const memory_desc_wrapper dst_mdw(pd->dst_md()); |
76 | auto ndims = dst_mdw.ndims(); |
77 | auto nelems = dst_mdw.nelems(true); |
78 | auto data_type_size = dst_mdw.data_type_size(); |
79 | |
80 | if (nelems == 0) return status::unimplemented; |
81 | |
82 | const auto &blk = dst_mdw.blocking_desc(); |
83 | const auto concat_dim = pd->concat_dim(); |
84 | // TODO: refactor to avoid duplication in get_ordered_dim_idxs and |
85 | // is_same_axis_order. |
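    // Identify every dim laid out outside the concat dim: extern_dim tracks
    // the innermost such dim and extern_axis their combined (blocked) extent.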
86 | dim_t extern_axis = 1; |
87 | int extern_dim = -1; |
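    // When the concat dim has unit extent, another dim may legally share its
    // stride and still lie outside it, so ties are accepted in that case.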
88 | bool equal_strides_ok = dst_mdw.padded_dims()[concat_dim] == 1; |
89 | std::vector<dim_t> blocks(ndims, 1); |
90 | for (int i = 0; i < blk.inner_nblks; ++i) |
91 | blocks[blk.inner_idxs[i]] *= blk.inner_blks[i]; |
92 | for (int i = 0; i < ndims; ++i) { |
93 | const auto &stride = blk.strides[i]; |
94 | if (stride > blk.strides[concat_dim] |
95 | || (equal_strides_ok && stride == blk.strides[concat_dim])) { |
96 | if (extern_dim == -1 || stride < blk.strides[extern_dim]) |
97 | extern_dim = i; |
98 | extern_axis *= dst_mdw.padded_dims()[i] / blocks[i]; |
99 | } |
100 | } |
101 | |
102 | int offset = 0; |
103 | bool has_padding = false; |
104 | const auto dst_dim_order = get_ordered_dim_idxs(dst_mdw); |
105 | const dim_t c_blks = blocks[concat_dim]; |
106 | for (int i = 0; i < pd->n_inputs(); ++i) { |
107 | const memory_desc_wrapper src_mdw(pd->src_md(i)); |
108 | |
109 | // check concat dim padding |
110 | if (src_mdw.padded_dims()[concat_dim] != src_mdw.dims()[concat_dim]) { |
111 | if (has_padding) |
112 | return status::unimplemented; |
113 | else |
114 | has_padding = true; |
115 | } |
116 | |
117 | if (src_mdw.data_type() != dst_mdw.data_type()) |
118 | return status::unimplemented; |
119 | |
120 | if (!types::blocking_desc_is_equal(*pd->dst_md(), *pd->src_md(i), true)) |
121 | return status::unimplemented; |
122 | |
123 | if (!is_same_axis_order(dst_dim_order, src_mdw)) |
124 | return status::unimplemented; |
125 | |
126 | if (!src_mdw.is_dense()) return status::unimplemented; |
127 | |
128 | const auto &src_blk = src_mdw.blocking_desc(); |
129 | const auto step = src_mdw.padded_dims()[concat_dim] / c_blks; |
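        // With no dim outside the concat dim, the source's full dense span
        // serves as its external stride.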
130 | auto src_extern_dim_size = (extern_dim == -1) |
131 | ? src_blk.strides[concat_dim] * step |
132 | : src_blk.strides[extern_dim]; |
133 | conf.src_extern_dim_sizes[i] = src_extern_dim_size * data_type_size; |
134 | conf.offset[i] = offset; |
135 | offset += step; |
136 | } |
137 | |
138 | auto concat_dim_size = dst_mdw.padded_dims()[concat_dim] / c_blks; |
139 | conf.dst_extern_dim_size = (extern_dim == -1) |
140 | ? blk.strides[concat_dim] * concat_dim_size |
141 | : blk.strides[extern_dim]; |
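    // inner_axis: the contiguous byte span below the concat dim; shift_in
    // below may grow it by folding in factors of the concat dim.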
142 | conf.inner_axis = blk.strides[concat_dim] * data_type_size; |
143 | conf.n = pd->n_inputs(); |
144 | |
145 | auto shift_in = [&concat_dim_size, &conf](int k) { |
146 | // partition concat_dim_size so that more data is read at once |
147 | if (concat_dim_size % k) return false; |
148 | for (int i = 0; i < conf.n; ++i) |
149 | if (conf.offset[i] % k) return false; |
150 | for (int i = 0; i < conf.n; ++i) |
151 | conf.offset[i] /= k; |
152 | concat_dim_size /= k; |
153 | conf.inner_axis *= k; |
154 | return true; |
155 | }; |
156 | while (shift_in(2)) |
157 | ; |
158 | for (auto k : {3, 5, 7}) |
159 | shift_in(k); |
160 | |
161 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
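    // Reinterpret the data in 4-byte units when the inner axis is 32-byte
    // aligned, otherwise in 2-byte units, to widen loads and stores.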
162 | conf.data_type_size = (conf.inner_axis % 32 == 0) ? 4 : 2; |
163 | conf.inner_axis /= conf.data_type_size; |
164 | |
165 | conf.dst_extern_dim_size |
166 | = conf.dst_extern_dim_size * data_type_size / conf.data_type_size; |
167 | conf.dst_offset0 = dst_mdw.offset0() * data_type_size / conf.data_type_size; |
168 | |
169 | auto set_gws_d = [&conf, extern_axis, concat_dim_size]() { |
170 | conf.gws_d[0] = conf.inner_axis / conf.block * conf.simd; |
171 | conf.gws_d[1] = extern_axis; |
172 | conf.gws_d[2] = concat_dim_size; |
173 | }; |
174 | |
175 | if (conf.inner_axis % 16 || conf.inner_axis < 32) { |
176 | // TODO: fix implementation so this check isn't necessary |
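        // Unaligned or short inner axis: fall back to a scalar (simd = 1)
        // path, which is only viable for multi-byte types and small grids.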
177 | if (data_type_size > 1) { |
178 | conf.simd = 1; |
179 | conf.block = 1; |
180 | set_gws_d(); |
181 | for (int i = 0; i < 3; ++i) { |
182 | if (conf.gws_d[i] > 1024) return status::unimplemented; |
183 | } |
184 | } else |
185 | return status::unimplemented; |
186 | } else { |
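        // Pick the widest subgroup size that divides the inner axis and
        // group up to eight subgroup-sized chunks into each block.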
187 | conf.simd = (conf.inner_axis % 16 == 0) ? 16 : 8; |
188 | conf.block = conf.simd * utils::max_div(conf.inner_axis / conf.simd, 8); |
189 | if (!compute_engine->mayiuse_sub_group(conf.simd)) |
190 | return status::unimplemented; |
191 | set_gws_d(); |
192 | } |
193 | |
194 | compute::get_optimal_lws(conf.gws_d, conf.lws_d, 3, 0, |
195 | compute_engine->device_info()->gpu_arch()); |
196 | return status::success; |
197 | } |
198 | |
199 | static status_t init_kernel_ctx_common( |
200 | compute::kernel_ctx_t &kernel_ctx, const concat_conf_t &conf) { |
201 | kernel_ctx.define_int("DST_EXT_OFFSET" , conf.dst_extern_dim_size); |
202 | for (int i = 0; i < conf.n; ++i) { |
203 | kernel_ctx.define_int(utils::format("SRC%d_EXT_OFFSET" , i), |
204 | conf.src_extern_dim_sizes[i] / conf.data_type_size); |
205 | kernel_ctx.define_int(utils::format("OFFSET%d" , i), conf.offset[i]); |
206 | } |
207 | kernel_ctx.define_int(utils::format("OFFSET%d" , conf.n), conf.gws_d[2]); |
208 | kernel_ctx.define_int("INNER_OFFSET" , conf.inner_axis); |
209 | kernel_ctx.define_int("BLOCK" , conf.block); |
210 | kernel_ctx.define_int("N_INPUTS" , conf.n); |
211 | kernel_ctx.define_int("SIMD" , conf.simd); |
212 | kernel_ctx.define_int("DATA_TYPE_SIZE" , conf.data_type_size); |
213 | kernel_ctx.print_options(); |
214 | return status::success; |
215 | } |
216 | |
217 | status_t simple_concat_t::pd_t::init_conf(engine_t *engine) { |
218 | return init_conf_common(engine, conf, this); |
219 | } |
220 | |
221 | status_t simple_concat_t::pd_t::init_kernel_ctx( |
222 | compute::kernel_ctx_t &kernel_ctx) const { |
223 | return init_kernel_ctx_common(kernel_ctx, conf); |
224 | } |
status_t simple_concat_t::execute_concat(const exec_ctx_t &ctx) const {
    const auto &conf = pd()->conf;
228 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
229 | |
230 | compute::kernel_arg_list_t arg_list; |
231 | arg_list.set(0, dst); |
232 | arg_list.set(1, conf.dst_offset0); |
233 | for (int i = 0; i < pd()->n_inputs(); ++i) { |
234 | auto &src = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + i); |
235 | arg_list.set(i + 2, src); |
236 | } |
237 | |
238 | auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d); |
239 | |
    return parallel_for(ctx, nd_range, kernel_, arg_list);
242 | } |
243 | } // namespace ocl |
244 | } // namespace gpu |
245 | } // namespace impl |
246 | } // namespace dnnl |
247 | |