1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gen9_pooling.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | static status_t init_conf_common(pool_conf_t &conf, offsets_t &off, |
25 | const pooling_pd_t *pd, engine_t *engine) { |
26 | using namespace dnnl::impl::format_tag; |
27 | |
28 | const memory_desc_wrapper src_mdw(pd->invariant_src_md()); |
29 | const memory_desc_wrapper dst_mdw(pd->invariant_dst_md()); |
30 | |
31 | auto is_c_dense = [](const memory_desc_wrapper &mdw) { |
32 | return mdw.blocking_desc().strides[1] == 1; |
33 | }; |
34 | auto is_c_blocked_by |
35 | = [](const memory_desc_wrapper &mdw, const int blockSize) { |
36 | auto &blk = mdw.blocking_desc(); |
37 | if (blk.inner_nblks == 0) return false; |
38 | return (blk.inner_idxs[blk.inner_nblks - 1] == 1) |
39 | && (blk.inner_blks[blk.inner_nblks - 1] == blockSize); |
40 | }; |
41 | |
42 | if (!is_c_blocked_by(src_mdw, 16) && !is_c_blocked_by(src_mdw, 32) |
43 | && !is_c_dense(src_mdw)) |
44 | return status::unimplemented; |
45 | |
46 | if (!is_c_blocked_by(dst_mdw, 16) && !is_c_blocked_by(dst_mdw, 32) |
47 | && !is_c_dense(dst_mdw)) |
48 | return status::unimplemented; |
49 | |
50 | int c_block_size = 1, n_block_size = 1; |
51 | auto &src_blk = src_mdw.blocking_desc(); |
52 | if (src_blk.inner_nblks > 0) { |
53 | // C is the last blocked dimension as it was checked in is_c_blocked_by |
54 | c_block_size = src_blk.inner_blks[src_blk.inner_nblks - 1]; |
55 | // if there is NC blocking (N is the blocked dimension before C) use N blocks as well |
56 | if (src_blk.inner_nblks > 1 |
57 | && src_blk.inner_idxs[src_blk.inner_nblks - 2] == 0) { |
58 | n_block_size = src_blk.inner_blks[src_blk.inner_nblks - 2]; |
59 | } |
60 | } |
61 | |
62 | set_default_pool_conf(conf, *pd->desc(), *pd->invariant_src_md(), |
63 | *pd->invariant_dst_md(), *pd->attr()); |
64 | |
65 | set_offsets(src_mdw, off.src_off); |
66 | set_offsets(dst_mdw, off.dst_off); |
67 | |
68 | conf.sub_group_size = 16; |
69 | conf.use_mb_c_block = false; |
70 | conf.use_only_c_block = false; |
71 | int c_padded = utils::rnd_up(conf.c_padded, conf.sub_group_size); |
72 | |
73 | if (c_block_size >= 16 && n_block_size >= 16) { |
74 | c_padded = utils::rnd_up(conf.c_padded, c_block_size); |
75 | conf.use_mb_c_block = true; |
76 | conf.vect_dt_n = 8; |
77 | conf.nvect = 2; |
78 | if (!pd->attr()->post_ops_.has_default_values()) { conf.nvect = 1; } |
79 | conf.chunks_per_c_block = c_block_size / conf.sub_group_size; |
80 | conf.chunks_per_mb_block |
81 | = conf.vect_dt_n * conf.nvect / conf.chunks_per_c_block; |
82 | } else if (c_block_size == 16 && n_block_size == 1) { |
83 | conf.use_only_c_block = true; |
84 | conf.vect_dt_n = 1; |
85 | conf.nvect = 2; |
86 | if (!pd->attr()->post_ops_.has_default_values()) { conf.nvect = 1; } |
87 | conf.chunks_per_c_block = conf.nvect * conf.vect_dt_n; |
88 | conf.chunks_per_mb_block = 1; |
89 | } else { |
90 | conf.use_only_c_block = true; |
91 | const size_t num_c_blocks = c_padded / conf.sub_group_size; |
92 | conf.vect_dt_n = 8; |
93 | while (num_c_blocks % conf.vect_dt_n != 0) { |
94 | conf.vect_dt_n /= 2; |
95 | } |
96 | if ((conf.vect_dt_n < 8) && (conf.mb_padded % 4 == 0)) { |
97 | conf.unroll_mb = true; |
98 | } |
99 | conf.nvect = 1; |
100 | conf.chunks_per_c_block = conf.nvect * conf.vect_dt_n; |
101 | conf.chunks_per_mb_block = 1; |
102 | } |
103 | if (conf.vect_dt_n < 4) { |
104 | // fallback to ref_pooling kernel for better perf. |
105 | return status::unimplemented; |
106 | } |
107 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
108 | conf.dispatch = compute_engine->create_dispatch( |
109 | conf.is_backward ? src_mdw.md_ : dst_mdw.md_); |
110 | |
111 | auto arch = compute_engine->device_info()->gpu_arch(); |
112 | bool is_pre_xe_hpc = arch < compute::gpu_arch_t::xe_hpc; |
113 | size_t input_sz_mb |
114 | = src_mdw.nelems() * src_mdw.data_type_size() / 1024 / 1024; |
115 | // heuristics: use batching on ATS-M for certain shapes for better perf. |
116 | if (!conf.is_backward && pd->attr()->post_ops_.has_default_values() |
117 | && !conf.unroll_mb && is_pre_xe_hpc |
118 | && (2 * input_sz_mb > (size_t)conf.mb)) { |
119 | conf.num_batches = utils::div_up(conf.mb_padded, conf.mb_block_size); |
120 | } |
121 | if (conf.num_batches > 1) { |
122 | conf.dispatch.define_dim("MB" , 0, |
123 | nstl::min(conf.mb_block_size, conf.mb_padded), |
124 | conf.chunks_per_mb_block); |
125 | } else { |
126 | if (conf.is_backward) { |
127 | conf.dispatch.define_dim("MB" , 0, |
128 | conf.unroll_mb ? conf.mb_padded / 4 : conf.mb_padded, |
129 | conf.chunks_per_mb_block); |
130 | } else { |
131 | conf.dispatch.define_dim("MB" , 0, |
132 | conf.unroll_mb ? conf.mb_padded / 2 : conf.mb_padded, |
133 | conf.chunks_per_mb_block); |
134 | } |
135 | } |
136 | conf.dispatch.define_dim("C" , 1, c_padded, conf.chunks_per_c_block); |
137 | |
138 | int ndims = conf.ndims; |
139 | if (!conf.is_backward) { |
140 | conf.dispatch.define_dim("OD" , nstl::max(2, ndims - 3), conf.od); |
141 | conf.dispatch.define_dim("OH" , nstl::max(2, ndims - 2), conf.oh); |
142 | conf.dispatch.define_dim("OW" , nstl::max(2, ndims - 1), conf.ow); |
143 | } else { |
144 | conf.dispatch.define_dim("ID" , nstl::max(2, ndims - 3), conf.id); |
145 | conf.dispatch.define_dim("IH" , nstl::max(2, ndims - 2), conf.ih); |
146 | conf.dispatch.define_dim("IW" , nstl::max(2, ndims - 1), conf.iw); |
147 | } |
148 | CHECK(conf.dispatch.vectorize_dim("C" , conf.sub_group_size)); |
149 | conf.dispatch.generate(); |
150 | |
151 | return status::success; |
152 | }; |
153 | |
154 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
155 | const pool_conf_t &conf, const offsets_t &off, |
156 | const post_ops_t &post_ops) { |
157 | using namespace dnnl::impl::alg_kind; |
158 | kernel_ctx.set_data_type(conf.src_dt); |
159 | |
160 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
161 | if (conf.num_batches > 1) { |
162 | kernel_ctx.define_int("MB" , nstl::min(conf.mb_block_size, conf.mb)); |
163 | } else { |
164 | kernel_ctx.define_int("MB" , conf.mb); |
165 | } |
166 | kernel_ctx.define_int("MB_BLOCK_SIZE" , conf.mb_block_size); |
167 | kernel_ctx.define_int("C_W_PADDING" , conf.c_padded); |
168 | kernel_ctx.define_int("C_WO_PADDING" , conf.c); |
169 | kernel_ctx.define_int("ID" , conf.id); |
170 | kernel_ctx.define_int("IH" , conf.ih); |
171 | kernel_ctx.define_int("IW" , conf.iw); |
172 | kernel_ctx.define_int("OD" , conf.od); |
173 | kernel_ctx.define_int("OH" , conf.oh); |
174 | kernel_ctx.define_int("OW" , conf.ow); |
175 | kernel_ctx.define_int("KD" , conf.kd); |
176 | kernel_ctx.define_int("KH" , conf.kh); |
177 | kernel_ctx.define_int("KW" , conf.kw); |
178 | kernel_ctx.define_int("SD" , conf.stride_d); |
179 | kernel_ctx.define_int("SH" , conf.stride_h); |
180 | kernel_ctx.define_int("SW" , conf.stride_w); |
181 | kernel_ctx.define_int("PD" , conf.f_pad); |
182 | kernel_ctx.define_int("PH" , conf.t_pad); |
183 | kernel_ctx.define_int("PW" , conf.l_pad); |
184 | kernel_ctx.define_int("SUB_GROUP_SIZE" , conf.sub_group_size); |
185 | kernel_ctx.define_int("IS_TRAINING" , conf.is_training); |
186 | kernel_ctx.define_int("IS_BWD" , conf.is_backward); |
187 | kernel_ctx.define_int("IS_FWD" , !conf.is_backward); |
188 | |
189 | kernel_ctx.define_int("ALG_MAX" , (conf.alg == pooling_max)); |
190 | kernel_ctx.define_int( |
191 | "ALG_AVG_NP" , (conf.alg == pooling_avg_exclude_padding)); |
192 | kernel_ctx.define_int( |
193 | "ALG_AVG_P" , (conf.alg == pooling_avg_include_padding)); |
194 | |
195 | kernel_ctx.define_int("VECT_DT_N" , conf.vect_dt_n); |
196 | kernel_ctx.define_int("NVECT" , conf.nvect); |
197 | kernel_ctx.define_int("USE_ONLY_C_BLOCK" , conf.use_only_c_block); |
198 | kernel_ctx.define_int("USE_MB_C_BLOCK" , conf.use_mb_c_block); |
199 | kernel_ctx.define_int("CHUNKS_PER_C_BLOCK" , conf.chunks_per_c_block); |
200 | kernel_ctx.define_int("CHUNKS_PER_MB_BLOCK" , conf.chunks_per_mb_block); |
201 | kernel_ctx.define_int("UNROLL_MB" , conf.unroll_mb); |
202 | |
203 | kernel_ctx.add_option("-Dcl_intel_subgroups_char" ); |
204 | |
205 | def_offsets(off.src_off, kernel_ctx, "SRC" , conf.ndims); |
206 | def_offsets(off.dst_off, kernel_ctx, "DST" , conf.ndims); |
207 | |
208 | def_attr_info(kernel_ctx, conf.attr_info, post_ops); |
209 | |
210 | def_dispatch(kernel_ctx, conf.dispatch); |
211 | |
212 | return status::success; |
213 | } |
214 | |
// Forward: derive the kernel configuration from this primitive descriptor.
status_t gen9_pooling_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}
218 | |
// Forward: populate the OpenCL build context from the stored configuration.
status_t gen9_pooling_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}
223 | |
224 | status_t gen9_pooling_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
225 | |
226 | status_t status = status::success; |
227 | |
228 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
229 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
230 | auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE); |
231 | |
232 | compute::kernel_arg_list_t arg_list; |
233 | arg_list.set(0, src); |
234 | arg_list.set(1, ws); |
235 | arg_list.set(2, dst); |
236 | append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_); |
237 | |
238 | auto nd_range = pd()->conf.dispatch.nd_range(); |
239 | |
240 | int num_batches = pd()->conf.num_batches; |
241 | for (int batch_iter = 0; batch_iter < num_batches; batch_iter++) { |
242 | arg_list.set(3, batch_iter); |
243 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
244 | if (status != status::success) return status; |
245 | } |
246 | return status; |
247 | } |
248 | |
// Backward: derive the kernel configuration from this primitive descriptor.
status_t gen9_pooling_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}
252 | |
// Backward: populate the OpenCL build context from the stored configuration.
status_t gen9_pooling_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}
257 | |
258 | status_t gen9_pooling_bwd_t::execute_backward(const exec_ctx_t &ctx) const { |
259 | |
260 | status_t status = status::success; |
261 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
262 | CHECK(status); |
263 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
264 | auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE); |
265 | |
266 | compute::kernel_arg_list_t arg_list; |
267 | arg_list.set(0, diff_src); |
268 | arg_list.set(1, ws); |
269 | arg_list.set(2, diff_dst); |
270 | |
271 | auto nd_range = pd()->conf.dispatch.nd_range(); |
272 | |
273 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
274 | |
275 | return status; |
276 | } |
277 | |
278 | } // namespace ocl |
279 | } // namespace gpu |
280 | } // namespace impl |
281 | } // namespace dnnl |
282 | |