/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>

#include "gpu/compute/dispatch.hpp"
#include "gpu/ocl/simple_concat.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

/* Returns dimension indices in (our best guess at) nesting order */
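// E.g., for a dense NCHW tensor this yields {3, 2, 1, 0}: the unit-stride W
// axis comes first and the largest-stride N axis last.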
std::vector<int> get_ordered_dim_idxs(const memory_desc_wrapper &mdw) {
    const auto ndims = mdw.ndims();
    std::vector<int> idxs(ndims);
    for (int i = 0; i < ndims; ++i)
        idxs[i] = i;

    const auto &strides = mdw.blocking_desc().strides;
    const auto &sizes = mdw.dims();
    const auto cmp = [&](int a, int b) {
        // Dimensions of size 1 have the same stride as the next outer
        // dimension, so sorting by strides alone will not necessarily give the
        // correct order. In the case of ties, the dim with the larger size is
        // sorted second. Permutations of size-1 dims with the same stride do
        // not affect the correctness of the concat.
        return strides[a] < strides[b]
                || (strides[a] == strides[b] && sizes[a] < sizes[b]);
    };
    std::sort(idxs.begin(), idxs.end(), cmp);
    return idxs;
}

/** Returns true if two sets of data have the same axis order. */
bool is_same_axis_order(
        const std::vector<int> &idxs, const memory_desc_wrapper &mdw) {
    const auto ndims = mdw.ndims();

    // Compute the total block size for each dim to help predict strides
    std::vector<dim_t> blocks(ndims, 1);
    const auto &blkg = mdw.blocking_desc();
    for (int i = 0; i < blkg.inner_nblks; ++i)
        blocks[blkg.inner_idxs[i]] *= blkg.inner_blks[i];

    // Check that the order specified by idxs matches the src tensor
    dim_t min_stride = 1;
    const auto &padded_dims = mdw.padded_dims();
    for (auto idx : idxs) {
        auto stride = blkg.strides[idx];
        if (stride < min_stride) return false;
        auto step = utils::div_up(padded_dims[idx], blocks[idx]);
        min_stride = stride * step;
    }
    return true;
}

static status_t init_conf_common(
        engine_t *engine, concat_conf_t &conf, const concat_pd_t *pd) {
    using namespace utils;

    const memory_desc_wrapper dst_mdw(pd->dst_md());
    auto ndims = dst_mdw.ndims();
    auto nelems = dst_mdw.nelems(true);
    auto data_type_size = dst_mdw.data_type_size();

    if (nelems == 0) return status::unimplemented;

    const auto &blk = dst_mdw.blocking_desc();
    const auto concat_dim = pd->concat_dim();
    // TODO: refactor to avoid duplication in get_ordered_dim_idxs and
    // is_same_axis_order.
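    // Find the "external" part of the destination: dims whose strides are
    // larger than the concat dim's stride (i.e., dims nested outside of it).
    // extern_dim tracks the innermost such dim and extern_axis the product of
    // their block-normalized padded sizes.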
    dim_t extern_axis = 1;
    int extern_dim = -1;
    bool equal_strides_ok = dst_mdw.padded_dims()[concat_dim] == 1;
    std::vector<dim_t> blocks(ndims, 1);
    for (int i = 0; i < blk.inner_nblks; ++i)
        blocks[blk.inner_idxs[i]] *= blk.inner_blks[i];
    for (int i = 0; i < ndims; ++i) {
        const auto &stride = blk.strides[i];
        if (stride > blk.strides[concat_dim]
                || (equal_strides_ok && stride == blk.strides[concat_dim])) {
            if (extern_dim == -1 || stride < blk.strides[extern_dim])
                extern_dim = i;
            extern_axis *= dst_mdw.padded_dims()[i] / blocks[i];
        }
    }

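    // Validate each source and record its offset (in concat-dim blocks) as
    // well as the byte size of its external dim. Sources must match the
    // destination's data type, blocking, and axis order, must be dense, and at
    // most one of them may be padded along the concat dim.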
    int offset = 0;
    bool has_padding = false;
    const auto dst_dim_order = get_ordered_dim_idxs(dst_mdw);
    const dim_t c_blks = blocks[concat_dim];
    for (int i = 0; i < pd->n_inputs(); ++i) {
        const memory_desc_wrapper src_mdw(pd->src_md(i));

        // check concat dim padding
        if (src_mdw.padded_dims()[concat_dim] != src_mdw.dims()[concat_dim]) {
            if (has_padding)
                return status::unimplemented;
            else
                has_padding = true;
        }

        if (src_mdw.data_type() != dst_mdw.data_type())
            return status::unimplemented;

        if (!types::blocking_desc_is_equal(*pd->dst_md(), *pd->src_md(i), true))
            return status::unimplemented;

        if (!is_same_axis_order(dst_dim_order, src_mdw))
            return status::unimplemented;

        if (!src_mdw.is_dense()) return status::unimplemented;

        const auto &src_blk = src_mdw.blocking_desc();
        const auto step = src_mdw.padded_dims()[concat_dim] / c_blks;
        auto src_extern_dim_size = (extern_dim == -1)
                ? src_blk.strides[concat_dim] * step
                : src_blk.strides[extern_dim];
        conf.src_extern_dim_sizes[i] = src_extern_dim_size * data_type_size;
        conf.offset[i] = offset;
        offset += step;
    }

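    // inner_axis is the number of contiguous bytes contributed per step along
    // the concat dim; dst_extern_dim_size is the external dim's stride in the
    // destination, still measured in elements at this point.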
    auto concat_dim_size = dst_mdw.padded_dims()[concat_dim] / c_blks;
    conf.dst_extern_dim_size = (extern_dim == -1)
            ? blk.strides[concat_dim] * concat_dim_size
            : blk.strides[extern_dim];
    conf.inner_axis = blk.strides[concat_dim] * data_type_size;
    conf.n = pd->n_inputs();

    auto shift_in = [&concat_dim_size, &conf](int k) {
        // partition concat_dim_size so that more data is read at once
        if (concat_dim_size % k) return false;
        for (int i = 0; i < conf.n; ++i)
            if (conf.offset[i] % k) return false;
        for (int i = 0; i < conf.n; ++i)
            conf.offset[i] /= k;
        concat_dim_size /= k;
        conf.inner_axis *= k;
        return true;
    };
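    // Fold common factors of the concat dim into the inner axis: all factors
    // of 2, then at most one factor each of 3, 5, and 7. A factor is only
    // folded when every source offset is divisible by it as well.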
    while (shift_in(2))
        ;
    for (auto k : {3, 5, 7})
        shift_in(k);

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
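    // Reinterpret the contiguous inner axis as 4-byte or 2-byte elements so
    // that accesses are wider than the original data type; the byte sizes and
    // offsets below are rescaled accordingly.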
    conf.data_type_size = (conf.inner_axis % 32 == 0) ? 4 : 2;
    conf.inner_axis /= conf.data_type_size;

    conf.dst_extern_dim_size
            = conf.dst_extern_dim_size * data_type_size / conf.data_type_size;
    conf.dst_offset0 = dst_mdw.offset0() * data_type_size / conf.data_type_size;

    auto set_gws_d = [&conf, extern_axis, concat_dim_size]() {
        conf.gws_d[0] = conf.inner_axis / conf.block * conf.simd;
        conf.gws_d[1] = extern_axis;
        conf.gws_d[2] = concat_dim_size;
    };

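    // Small or misaligned inner axes fall back to a scalar (simd = 1) path;
    // otherwise dispatch 16- or 8-wide subgroups, each covering a block of up
    // to 8 * simd inner-axis elements.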
    if (conf.inner_axis % 16 || conf.inner_axis < 32) {
        // TODO: fix implementation so this check isn't necessary
        if (data_type_size > 1) {
            conf.simd = 1;
            conf.block = 1;
            set_gws_d();
            for (int i = 0; i < 3; ++i) {
                if (conf.gws_d[i] > 1024) return status::unimplemented;
            }
        } else
            return status::unimplemented;
    } else {
        conf.simd = (conf.inner_axis % 16 == 0) ? 16 : 8;
        conf.block = conf.simd * utils::max_div(conf.inner_axis / conf.simd, 8);
        if (!compute_engine->mayiuse_sub_group(conf.simd))
            return status::unimplemented;
        set_gws_d();
    }

    compute::get_optimal_lws(conf.gws_d, conf.lws_d, 3, 0,
            compute_engine->device_info()->gpu_arch());
    return status::success;
}

static status_t init_kernel_ctx_common(
        compute::kernel_ctx_t &kernel_ctx, const concat_conf_t &conf) {
    kernel_ctx.define_int("DST_EXT_OFFSET", conf.dst_extern_dim_size);
    for (int i = 0; i < conf.n; ++i) {
        kernel_ctx.define_int(utils::format("SRC%d_EXT_OFFSET", i),
                conf.src_extern_dim_sizes[i] / conf.data_type_size);
        kernel_ctx.define_int(utils::format("OFFSET%d", i), conf.offset[i]);
    }
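    // The final OFFSET define (index N_INPUTS) carries the total concat dim
    // size (conf.gws_d[2]), effectively the end offset of the last source.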
    kernel_ctx.define_int(utils::format("OFFSET%d", conf.n), conf.gws_d[2]);
    kernel_ctx.define_int("INNER_OFFSET", conf.inner_axis);
    kernel_ctx.define_int("BLOCK", conf.block);
    kernel_ctx.define_int("N_INPUTS", conf.n);
    kernel_ctx.define_int("SIMD", conf.simd);
    kernel_ctx.define_int("DATA_TYPE_SIZE", conf.data_type_size);
    kernel_ctx.print_options();
    return status::success;
}

status_t simple_concat_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(engine, conf, this);
}

status_t simple_concat_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}

status_t simple_concat_t::execute_concat(const exec_ctx_t &ctx) const {
    const auto &conf = pd()->conf;
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, dst);
    arg_list.set(1, conf.dst_offset0);
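    // Sources are bound after dst and dst_offset0, at kernel arg indices
    // 2 .. n_inputs() + 1, in the same order as the SRC%d_EXT_OFFSET defines.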
    for (int i = 0; i < pd()->n_inputs(); ++i) {
        auto &src = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + i);
        arg_list.set(i + 2, src);
    }

    auto nd_range = compute::nd_range_t(conf.gws_d, conf.lws_d);

    status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
    return status;
}
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl