1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <math.h> |
18 | |
19 | #include "common/primitive_exec_types.hpp" |
20 | |
21 | #include "gpu/ocl/gen9_concat.hpp" |
22 | #include "gpu/ocl/ocl_utils.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace ocl { |
28 | |
29 | std::pair<int, int> gen9_concat_t::pd_t::calculate_iter_dim_idx_chunk( |
30 | int num_threads) const { |
31 | if (conf.ndims == 1) return std::make_pair(0, 1); |
32 | const auto &dst_dims = conf.dst_md_info.padded_dims; |
33 | int max_dim_idx = -1; |
34 | int max_dim = -1; |
35 | for (int dim_idx = conf.ndims - 1; dim_idx >= 0; dim_idx--) { |
36 | if (dst_dims[dim_idx] > max_dim && dim_idx != conf.concat_axis) { |
37 | max_dim = dst_dims[dim_idx]; |
38 | max_dim_idx = dim_idx; |
39 | } |
40 | } |
41 | const int iter_dim_idx = max_dim_idx; |
42 | const int all_elems = utils::array_product(dst_dims, conf.ndims); |
43 | const int max_iter_dim_chunk = 1024; |
44 | const int min_threads = num_threads * 4; |
45 | int iter_dim_chunk = std::min(dst_dims[iter_dim_idx], max_iter_dim_chunk); |
46 | const auto get_num_threads = [&]() { |
47 | return ceil(static_cast<float>(all_elems) |
48 | / (iter_dim_chunk * conf.sub_group_size)); |
49 | }; |
50 | while (get_num_threads() < min_threads && iter_dim_chunk > 1) { |
51 | iter_dim_chunk = ceil(iter_dim_chunk / 2.0f); |
52 | } |
53 | return std::make_pair(iter_dim_idx, iter_dim_chunk); |
54 | } |
55 | |
56 | bool gen9_concat_t::pd_t::can_use_sub_group_size( |
57 | const compute::compute_engine_t *compute_engine, int sub_group_size) { |
58 | auto is_dim_dense = [](const memory_desc_wrapper &mdw, int dim_idx) { |
59 | return mdw.blocking_desc().strides[dim_idx] == 1; |
60 | }; |
61 | auto get_dim_block = [](const memory_desc_wrapper &mdw, int dim_idx) { |
62 | const auto &blk = mdw.blocking_desc(); |
63 | if (blk.inner_nblks == 0 |
64 | || blk.inner_idxs[blk.inner_nblks - 1] != dim_idx) |
65 | return static_cast<dnnl_dim_t>(1); |
66 | return blk.inner_blks[blk.inner_nblks - 1]; |
67 | }; |
68 | |
69 | const concat_pd_t *pd = this; |
70 | const int c_idx = 1; |
71 | const memory_desc_wrapper dst_mdw(pd->dst_md()); |
72 | const bool is_dst_blocked = dst_mdw.blocking_desc().inner_nblks > 0; |
73 | bool is_concat_axis_aligned = true; |
74 | bool layouts_compatible = is_dst_blocked |
75 | ? get_dim_block(dst_mdw, c_idx) % sub_group_size == 0 |
76 | : is_dim_dense(dst_mdw, c_idx); |
77 | for (int i = 0; i < conf.n; ++i) { |
78 | const memory_desc_wrapper src_mdw(pd->src_md(i)); |
79 | is_concat_axis_aligned = is_concat_axis_aligned |
80 | && src_mdw.md_->dims[conf.concat_axis] % sub_group_size == 0; |
81 | if (is_dst_blocked) { |
82 | layouts_compatible = layouts_compatible |
83 | && get_dim_block(src_mdw, c_idx) % sub_group_size == 0; |
84 | } else { |
85 | layouts_compatible |
86 | = layouts_compatible && is_dim_dense(src_mdw, c_idx); |
87 | } |
88 | } |
89 | return is_concat_axis_aligned && layouts_compatible |
90 | && compute_engine->mayiuse_sub_group(sub_group_size); |
91 | } |
92 | |
93 | int gen9_concat_t::pd_t::calculate_sub_group_size( |
94 | const compute::compute_engine_t *compute_engine) { |
95 | // Subgroups are used only for concatenation over C dimension |
96 | if (conf.concat_axis != 1) return 1; |
97 | for (int sub_group_size : {16, 8}) { |
98 | if (can_use_sub_group_size(compute_engine, sub_group_size)) { |
99 | return sub_group_size; |
100 | } |
101 | } |
102 | return 1; |
103 | } |
104 | |
// Fills `conf` from the primitive descriptor and builds the OpenCL dispatch
// for the gen9 concat kernel. Returns status::unimplemented for plain
// (non-blocked) destinations that can't benefit from this kernel.
status_t gen9_concat_t::pd_t::init_conf(engine_t *engine) {
    const concat_pd_t *pd = this;

    const memory_desc_wrapper dst_mdw(pd->dst_md());
    const auto dst_dims = dst_mdw.md_->padded_dims;

    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);
    conf.dst_type = dst_mdw.data_type();
    conf.dst_offset0 = dst_mdw.offset0();
    // NOTE(review): src data type is taken from input 0 only — presumably all
    // inputs share one type (enforced elsewhere); confirm against concat_pd_t.
    conf.src_type = memory_desc_wrapper(pd->src_md(0)).data_type();
    conf.ndims = dst_mdw.ndims();
    const auto *compute_engine
            = utils::downcast<compute::compute_engine_t *>(engine);
    conf.dispatch = compute_engine->create_dispatch(dst_mdw.md_);
    conf.n = pd->n_inputs();
    conf.concat_axis = pd->concat_dim();

    // conf.offset[i] is the exclusive running end (prefix sum) of each input's
    // extent along the concat axis; the kernel uses it to route dst positions
    // to the right source.
    int concat_axis_end = 0;
    for (int i = 0; i < conf.n; ++i) {
        const memory_desc_wrapper src_mdw(pd->src_md(i));
        concat_axis_end += src_mdw.md_->dims[conf.concat_axis];
        conf.offset[i] = concat_axis_end;
        conf.src_md_infos[i] = memory_desc_info_t::create(pd->src_md(i));
    }

    conf.sub_group_size = calculate_sub_group_size(compute_engine);
    std::tie(conf.iter_dim_idx, conf.iter_dim_chunk)
            = calculate_iter_dim_idx_chunk(
                    compute_engine->device_info()->hw_threads());

    // For plain (non-blocked) dst layouts, bail out when subgroups can't be
    // used or when, for >2D tensors, chunking degenerated to one element —
    // this kernel would bring no benefit over the generic implementation.
    if (dst_mdw.blocking_desc().inner_nblks == 0
            && (conf.sub_group_size == 1
                    || (conf.ndims > 2 && conf.iter_dim_chunk == 1))) {
        return status::unimplemented;
    }

    // Define one dispatch dim D0..D<MAX_NDIMS-1> per tensor dim; dims beyond
    // conf.ndims collapse to size 1. Only the iteration dim gets a block
    // (chunk) larger than 1.
    for (int dim_idx = 0; dim_idx < MAX_NDIMS; dim_idx++) {
        const int dim_block
                = conf.iter_dim_idx == dim_idx ? conf.iter_dim_chunk : 1;
        const int dim_size = conf.ndims > dim_idx ? dst_dims[dim_idx] : 1;
        conf.dispatch.define_dim(
                utils::format("D%d", dim_idx), 0, dim_size, dim_block);
    }
    // Vectorize over the channel dim (D1) when subgroups are in use.
    if (conf.sub_group_size > 1) {
        conf.dispatch.vectorize_dim("D1", conf.sub_group_size);
    }
    conf.dispatch.generate();

    return status::success;
}
155 | |
// Translates `conf` into preprocessor defines for the OpenCL concat kernel.
static status_t init_kernel_ctx_common(
        compute::kernel_ctx_t &kernel_ctx, const concat_conf_t &conf) {
    // Per-input defines: SRC<i>_END is the exclusive end of input i along the
    // concat axis (prefix sum computed in init_conf), plus the input's memory
    // descriptor under the SRC<i> prefix.
    for (int i = 0; i < conf.n; ++i) {
        kernel_ctx.define_int(utils::format("SRC%d_END", i), conf.offset[i]);
        def_memory_desc_info(kernel_ctx, conf.src_md_infos[i],
                utils::format("SRC%d", i).c_str());
    }
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");

    kernel_ctx.set_data_type(conf.src_type);

    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("CONCAT_AXIS", conf.concat_axis);
    kernel_ctx.define_int("NUM_INPUTS", conf.n);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
    // Vector width is fixed at 1 (no wide loads/stores in this kernel).
    kernel_ctx.define_int("VECT_DT_N", 1);
    // Iteration-dim parameters chosen by calculate_iter_dim_idx_chunk.
    kernel_ctx.define_int("ITER_DIM_PADDED_SIZE",
            conf.dst_md_info.padded_dims[conf.iter_dim_idx]);
    kernel_ctx.define_int("ITER_DIM_IDX", conf.iter_dim_idx);
    kernel_ctx.define_int("ITER_DIM_CHUNK", conf.iter_dim_chunk);

    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}
181 | |
// Thin member-function forwarder to the file-local helper above.
status_t gen9_concat_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}
186 | |
187 | status_t gen9_concat_t::execute_concat(const exec_ctx_t &ctx) const { |
188 | status_t status; |
189 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
190 | |
191 | const auto &conf = pd()->conf; |
192 | compute::kernel_arg_list_t arg_list; |
193 | arg_list.set(0, dst); |
194 | arg_list.set(1, conf.dst_offset0); |
195 | for (int i = 0; i < 16; ++i) { |
196 | auto &src = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + i); |
197 | arg_list.set(i + 2, src); |
198 | } |
199 | |
200 | auto nd_range = conf.dispatch.nd_range(); |
201 | |
202 | status = parallel_for(ctx, nd_range, kernel, arg_list); |
203 | return status; |
204 | } |
205 | |
206 | } // namespace ocl |
207 | } // namespace gpu |
208 | } // namespace impl |
209 | } // namespace dnnl |
210 | |