/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <math.h>

#include "common/primitive_exec_types.hpp"

#include "gpu/ocl/gen9_concat.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

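// Heuristic for choosing the iteration dimension and its chunk size: the
// largest destination dimension other than the concat axis is iterated over,
// and the chunk is repeatedly halved (down to 1) until the dispatch yields at
// least 4x the number of available hardware threads. For example, for a
// 32x64x56x56 destination concatenated over axis 1, dimension index 3 is
// chosen and the initial chunk is min(56, 1024) = 56.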
std::pair<int, int> gen9_concat_t::pd_t::calculate_iter_dim_idx_chunk(
        int num_threads) const {
    if (conf.ndims == 1) return std::make_pair(0, 1);
    const auto &dst_dims = conf.dst_md_info.padded_dims;
    int max_dim_idx = -1;
    int max_dim = -1;
    for (int dim_idx = conf.ndims - 1; dim_idx >= 0; dim_idx--) {
        if (dst_dims[dim_idx] > max_dim && dim_idx != conf.concat_axis) {
            max_dim = dst_dims[dim_idx];
            max_dim_idx = dim_idx;
        }
    }
    const int iter_dim_idx = max_dim_idx;
    const int all_elems = utils::array_product(dst_dims, conf.ndims);
    const int max_iter_dim_chunk = 1024;
    const int min_threads = num_threads * 4;
    int iter_dim_chunk = std::min(dst_dims[iter_dim_idx], max_iter_dim_chunk);
    const auto get_num_threads = [&]() {
        return ceil(static_cast<float>(all_elems)
                / (iter_dim_chunk * conf.sub_group_size));
    };
    while (get_num_threads() < min_threads && iter_dim_chunk > 1) {
        iter_dim_chunk = ceil(iter_dim_chunk / 2.0f);
    }
    return std::make_pair(iter_dim_idx, iter_dim_chunk);
}

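// A sub-group of the given size can be used only when every source size along
// the concat axis is divisible by the sub-group size, the channel dimension
// layout is compatible (innermost block divisible by the sub-group size for
// blocked layouts, dense channels for plain layouts), and the engine supports
// that sub-group size.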
bool gen9_concat_t::pd_t::can_use_sub_group_size(
        const compute::compute_engine_t *compute_engine, int sub_group_size) {
    auto is_dim_dense = [](const memory_desc_wrapper &mdw, int dim_idx) {
        return mdw.blocking_desc().strides[dim_idx] == 1;
    };
    auto get_dim_block = [](const memory_desc_wrapper &mdw, int dim_idx) {
        const auto &blk = mdw.blocking_desc();
        if (blk.inner_nblks == 0
                || blk.inner_idxs[blk.inner_nblks - 1] != dim_idx)
            return static_cast<dnnl_dim_t>(1);
        return blk.inner_blks[blk.inner_nblks - 1];
    };

    const concat_pd_t *pd = this;
    const int c_idx = 1;
    const memory_desc_wrapper dst_mdw(pd->dst_md());
    const bool is_dst_blocked = dst_mdw.blocking_desc().inner_nblks > 0;
    bool is_concat_axis_aligned = true;
    bool layouts_compatible = is_dst_blocked
            ? get_dim_block(dst_mdw, c_idx) % sub_group_size == 0
            : is_dim_dense(dst_mdw, c_idx);
    for (int i = 0; i < conf.n; ++i) {
        const memory_desc_wrapper src_mdw(pd->src_md(i));
        is_concat_axis_aligned = is_concat_axis_aligned
                && src_mdw.md_->dims[conf.concat_axis] % sub_group_size == 0;
        if (is_dst_blocked) {
            layouts_compatible = layouts_compatible
                    && get_dim_block(src_mdw, c_idx) % sub_group_size == 0;
        } else {
            layouts_compatible
                    = layouts_compatible && is_dim_dense(src_mdw, c_idx);
        }
    }
    return is_concat_axis_aligned && layouts_compatible
            && compute_engine->mayiuse_sub_group(sub_group_size);
}

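// Prefer the widest supported sub-group (16, then 8); fall back to 1 when
// sub-groups cannot be used.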
int gen9_concat_t::pd_t::calculate_sub_group_size(
        const compute::compute_engine_t *compute_engine) {
    // Sub-groups are used only for concatenation over the C dimension
    if (conf.concat_axis != 1) return 1;
    for (int sub_group_size : {16, 8}) {
        if (can_use_sub_group_size(compute_engine, sub_group_size)) {
            return sub_group_size;
        }
    }
    return 1;
}

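// Populates the kernel configuration: source/destination descriptor info,
// running end offsets of each input along the concat axis, the sub-group
// size, the iteration dimension and chunk, and the nd-range dispatch.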
status_t gen9_concat_t::pd_t::init_conf(engine_t *engine) {
    const concat_pd_t *pd = this;

    const memory_desc_wrapper dst_mdw(pd->dst_md());
    const auto dst_dims = dst_mdw.md_->padded_dims;

    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);
    conf.dst_type = dst_mdw.data_type();
    conf.dst_offset0 = dst_mdw.offset0();
    conf.src_type = memory_desc_wrapper(pd->src_md(0)).data_type();
    conf.ndims = dst_mdw.ndims();
    const auto *compute_engine
            = utils::downcast<compute::compute_engine_t *>(engine);
    conf.dispatch = compute_engine->create_dispatch(dst_mdw.md_);
    conf.n = pd->n_inputs();
    conf.concat_axis = pd->concat_dim();

    int concat_axis_end = 0;
    for (int i = 0; i < conf.n; ++i) {
        const memory_desc_wrapper src_mdw(pd->src_md(i));
        concat_axis_end += src_mdw.md_->dims[conf.concat_axis];
        conf.offset[i] = concat_axis_end;
        conf.src_md_infos[i] = memory_desc_info_t::create(pd->src_md(i));
    }

    conf.sub_group_size = calculate_sub_group_size(compute_engine);
    std::tie(conf.iter_dim_idx, conf.iter_dim_chunk)
            = calculate_iter_dim_idx_chunk(
                    compute_engine->device_info()->hw_threads());

    if (dst_mdw.blocking_desc().inner_nblks == 0
            && (conf.sub_group_size == 1
                    || (conf.ndims > 2 && conf.iter_dim_chunk == 1))) {
        return status::unimplemented;
    }

    for (int dim_idx = 0; dim_idx < MAX_NDIMS; dim_idx++) {
        const int dim_block
                = conf.iter_dim_idx == dim_idx ? conf.iter_dim_chunk : 1;
        const int dim_size = conf.ndims > dim_idx ? dst_dims[dim_idx] : 1;
        conf.dispatch.define_dim(
                utils::format("D%d", dim_idx), 0, dim_size, dim_block);
    }
    if (conf.sub_group_size > 1) {
        conf.dispatch.vectorize_dim("D1", conf.sub_group_size);
    }
    conf.dispatch.generate();

    return status::success;
}

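// Translates the configuration into OpenCL build options: per-input SRC%d
// descriptors and SRC%d_END offsets, the DST descriptor, and the scalar
// parameters (NDIMS, CONCAT_AXIS, NUM_INPUTS, SUB_GROUP_SIZE, ...) consumed
// by the OpenCL concat kernel.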
static status_t init_kernel_ctx_common(
        compute::kernel_ctx_t &kernel_ctx, const concat_conf_t &conf) {
    for (int i = 0; i < conf.n; ++i) {
        kernel_ctx.define_int(utils::format("SRC%d_END", i), conf.offset[i]);
        def_memory_desc_info(kernel_ctx, conf.src_md_infos[i],
                utils::format("SRC%d", i).c_str());
    }
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");

    kernel_ctx.set_data_type(conf.src_type);

    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("CONCAT_AXIS", conf.concat_axis);
    kernel_ctx.define_int("NUM_INPUTS", conf.n);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
    kernel_ctx.define_int("VECT_DT_N", 1);
    kernel_ctx.define_int("ITER_DIM_PADDED_SIZE",
            conf.dst_md_info.padded_dims[conf.iter_dim_idx]);
    kernel_ctx.define_int("ITER_DIM_IDX", conf.iter_dim_idx);
    kernel_ctx.define_int("ITER_DIM_CHUNK", conf.iter_dim_chunk);

    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}

status_t gen9_concat_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}

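// Kernel argument layout: arg 0 is the destination buffer, arg 1 its element
// offset, followed by the source buffers. All 16 source slots are set
// unconditionally; slots beyond conf.n presumably resolve to empty storage,
// matching a kernel signature with a fixed number of source arguments.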
status_t gen9_concat_t::execute_concat(const exec_ctx_t &ctx) const {
    status_t status;
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

    const auto &conf = pd()->conf;
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, dst);
    arg_list.set(1, conf.dst_offset0);
    for (int i = 0; i < 16; ++i) {
        auto &src = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + i);
        arg_list.set(i + 2, src);
    }

    auto nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel, arg_list);
    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl