/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/gen9_pooling.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

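// Shared between the forward and backward implementations: validates the
// memory layouts, picks the vectorization/blocking scheme, and sets up the
// dispatch (ND-range) for the OpenCL kernel.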
static status_t init_conf_common(pool_conf_t &conf, offsets_t &off,
        const pooling_pd_t *pd, engine_t *engine) {
    using namespace dnnl::impl::format_tag;

    const memory_desc_wrapper src_mdw(pd->invariant_src_md());
    const memory_desc_wrapper dst_mdw(pd->invariant_dst_md());

    auto is_c_dense = [](const memory_desc_wrapper &mdw) {
        return mdw.blocking_desc().strides[1] == 1;
    };
    auto is_c_blocked_by
            = [](const memory_desc_wrapper &mdw, const int blockSize) {
                  auto &blk = mdw.blocking_desc();
                  if (blk.inner_nblks == 0) return false;
                  return (blk.inner_idxs[blk.inner_nblks - 1] == 1)
                          && (blk.inner_blks[blk.inner_nblks - 1] == blockSize);
              };

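    // Only channel-dense (stride 1 along C) or 16c/32c-blocked layouts are
    // handled here; anything else is left to a different pooling
    // implementation.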
    if (!is_c_blocked_by(src_mdw, 16) && !is_c_blocked_by(src_mdw, 32)
            && !is_c_dense(src_mdw))
        return status::unimplemented;

    if (!is_c_blocked_by(dst_mdw, 16) && !is_c_blocked_by(dst_mdw, 32)
            && !is_c_dense(dst_mdw))
        return status::unimplemented;

    int c_block_size = 1, n_block_size = 1;
    auto &src_blk = src_mdw.blocking_desc();
    if (src_blk.inner_nblks > 0) {
        // C is the innermost blocked dimension, as verified by is_c_blocked_by.
        c_block_size = src_blk.inner_blks[src_blk.inner_nblks - 1];
        // If there is NC blocking (N is the blocked dimension right before C),
        // pick up the N block size as well.
        if (src_blk.inner_nblks > 1
                && src_blk.inner_idxs[src_blk.inner_nblks - 2] == 0) {
            n_block_size = src_blk.inner_blks[src_blk.inner_nblks - 2];
        }
    }

    set_default_pool_conf(conf, *pd->desc(), *pd->invariant_src_md(),
            *pd->invariant_dst_md(), *pd->attr());

    set_offsets(src_mdw, off.src_off);
    set_offsets(dst_mdw, off.dst_off);

    conf.sub_group_size = 16;
    conf.use_mb_c_block = false;
    conf.use_only_c_block = false;
    int c_padded = utils::rnd_up(conf.c_padded, conf.sub_group_size);

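    // Pick a vectorization scheme based on the source blocking: layouts with
    // both N and C blocked (e.g., 32n16c) use combined MB+C chunking
    // (use_mb_c_block); other supported layouts use C-only chunking
    // (use_only_c_block).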
    if (c_block_size >= 16 && n_block_size >= 16) {
        c_padded = utils::rnd_up(conf.c_padded, c_block_size);
        conf.use_mb_c_block = true;
        conf.vect_dt_n = 8;
        conf.nvect = 2;
        if (!pd->attr()->post_ops_.has_default_values()) { conf.nvect = 1; }
        conf.chunks_per_c_block = c_block_size / conf.sub_group_size;
        conf.chunks_per_mb_block
                = conf.vect_dt_n * conf.nvect / conf.chunks_per_c_block;
    } else if (c_block_size == 16 && n_block_size == 1) {
        conf.use_only_c_block = true;
        conf.vect_dt_n = 1;
        conf.nvect = 2;
        if (!pd->attr()->post_ops_.has_default_values()) { conf.nvect = 1; }
        conf.chunks_per_c_block = conf.nvect * conf.vect_dt_n;
        conf.chunks_per_mb_block = 1;
    } else {
        conf.use_only_c_block = true;
        const size_t num_c_blocks = c_padded / conf.sub_group_size;
        conf.vect_dt_n = 8;
        while (num_c_blocks % conf.vect_dt_n != 0) {
            conf.vect_dt_n /= 2;
        }
        if ((conf.vect_dt_n < 8) && (conf.mb_padded % 4 == 0)) {
            conf.unroll_mb = true;
        }
        conf.nvect = 1;
        conf.chunks_per_c_block = conf.nvect * conf.vect_dt_n;
        conf.chunks_per_mb_block = 1;
    }
    if (conf.vect_dt_n < 4) {
        // Fall back to the ref_pooling kernel for better performance.
        return status::unimplemented;
    }
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    conf.dispatch = compute_engine->create_dispatch(
            conf.is_backward ? src_mdw.md_ : dst_mdw.md_);

    auto arch = compute_engine->device_info()->gpu_arch();
    bool is_pre_xe_hpc = arch < compute::gpu_arch_t::xe_hpc;
    size_t input_sz_mb
            = src_mdw.nelems() * src_mdw.data_type_size() / 1024 / 1024;
    // Heuristic: on pre-XeHPC GPUs (e.g., ATS-M), split large inputs into
    // batches over MB for better performance.
    if (!conf.is_backward && pd->attr()->post_ops_.has_default_values()
            && !conf.unroll_mb && is_pre_xe_hpc
            && (2 * input_sz_mb > (size_t)conf.mb)) {
        conf.num_batches = utils::div_up(conf.mb_padded, conf.mb_block_size);
    }
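    // The MB dispatch dimension covers a single batch chunk when batching is
    // enabled, otherwise the whole (possibly unrolled) minibatch.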
    if (conf.num_batches > 1) {
        conf.dispatch.define_dim("MB", 0,
                nstl::min(conf.mb_block_size, conf.mb_padded),
                conf.chunks_per_mb_block);
    } else {
        if (conf.is_backward) {
            conf.dispatch.define_dim("MB", 0,
                    conf.unroll_mb ? conf.mb_padded / 4 : conf.mb_padded,
                    conf.chunks_per_mb_block);
        } else {
            conf.dispatch.define_dim("MB", 0,
                    conf.unroll_mb ? conf.mb_padded / 2 : conf.mb_padded,
                    conf.chunks_per_mb_block);
        }
    }
    conf.dispatch.define_dim("C", 1, c_padded, conf.chunks_per_c_block);

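    // Forward kernels are parallelized over output spatial points, backward
    // kernels over input (diff_src) points.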
    int ndims = conf.ndims;
    if (!conf.is_backward) {
        conf.dispatch.define_dim("OD", nstl::max(2, ndims - 3), conf.od);
        conf.dispatch.define_dim("OH", nstl::max(2, ndims - 2), conf.oh);
        conf.dispatch.define_dim("OW", nstl::max(2, ndims - 1), conf.ow);
    } else {
        conf.dispatch.define_dim("ID", nstl::max(2, ndims - 3), conf.id);
        conf.dispatch.define_dim("IH", nstl::max(2, ndims - 2), conf.ih);
        conf.dispatch.define_dim("IW", nstl::max(2, ndims - 1), conf.iw);
    }
    CHECK(conf.dispatch.vectorize_dim("C", conf.sub_group_size));
    conf.dispatch.generate();

    return status::success;
}

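// Translates the configuration into preprocessor defines consumed by the
// OpenCL kernel source.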
static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const pool_conf_t &conf, const offsets_t &off,
        const post_ops_t &post_ops) {
    using namespace dnnl::impl::alg_kind;
    kernel_ctx.set_data_type(conf.src_dt);

    kernel_ctx.define_int("NDIMS", conf.ndims);
    if (conf.num_batches > 1) {
        kernel_ctx.define_int("MB", nstl::min(conf.mb_block_size, conf.mb));
    } else {
        kernel_ctx.define_int("MB", conf.mb);
    }
    kernel_ctx.define_int("MB_BLOCK_SIZE", conf.mb_block_size);
    kernel_ctx.define_int("C_W_PADDING", conf.c_padded);
    kernel_ctx.define_int("C_WO_PADDING", conf.c);
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("OD", conf.od);
    kernel_ctx.define_int("OH", conf.oh);
    kernel_ctx.define_int("OW", conf.ow);
    kernel_ctx.define_int("KD", conf.kd);
    kernel_ctx.define_int("KH", conf.kh);
    kernel_ctx.define_int("KW", conf.kw);
    kernel_ctx.define_int("SD", conf.stride_d);
    kernel_ctx.define_int("SH", conf.stride_h);
    kernel_ctx.define_int("SW", conf.stride_w);
    kernel_ctx.define_int("PD", conf.f_pad);
    kernel_ctx.define_int("PH", conf.t_pad);
    kernel_ctx.define_int("PW", conf.l_pad);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
    kernel_ctx.define_int("IS_TRAINING", conf.is_training);
    kernel_ctx.define_int("IS_BWD", conf.is_backward);
    kernel_ctx.define_int("IS_FWD", !conf.is_backward);

    kernel_ctx.define_int("ALG_MAX", (conf.alg == pooling_max));
    kernel_ctx.define_int(
            "ALG_AVG_NP", (conf.alg == pooling_avg_exclude_padding));
    kernel_ctx.define_int(
            "ALG_AVG_P", (conf.alg == pooling_avg_include_padding));

    kernel_ctx.define_int("VECT_DT_N", conf.vect_dt_n);
    kernel_ctx.define_int("NVECT", conf.nvect);
    kernel_ctx.define_int("USE_ONLY_C_BLOCK", conf.use_only_c_block);
    kernel_ctx.define_int("USE_MB_C_BLOCK", conf.use_mb_c_block);
    kernel_ctx.define_int("CHUNKS_PER_C_BLOCK", conf.chunks_per_c_block);
    kernel_ctx.define_int("CHUNKS_PER_MB_BLOCK", conf.chunks_per_mb_block);
    kernel_ctx.define_int("UNROLL_MB", conf.unroll_mb);

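    // Enables the kernel code paths guarded by cl_intel_subgroups_char
    // (8-bit sub-group block reads/writes).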
    kernel_ctx.add_option("-Dcl_intel_subgroups_char");

    def_offsets(off.src_off, kernel_ctx, "SRC", conf.ndims);
    def_offsets(off.dst_off, kernel_ctx, "DST", conf.ndims);

    def_attr_info(kernel_ctx, conf.attr_info, post_ops);

    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}

status_t gen9_pooling_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}

status_t gen9_pooling_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}

status_t gen9_pooling_fwd_t::execute_forward(const exec_ctx_t &ctx) const {

    status_t status = status::success;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
    auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE);

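    // Kernel argument layout: 0 = src, 1 = workspace, 2 = dst, 3 = current
    // batch index (set per iteration below), post-op arguments start at 4.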
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, ws);
    arg_list.set(2, dst);
    append_post_ops_to_arg_list(ctx, arg_list, 4, pd()->attr()->post_ops_);

    auto nd_range = pd()->conf.dispatch.nd_range();

    int num_batches = pd()->conf.num_batches;
    for (int batch_iter = 0; batch_iter < num_batches; batch_iter++) {
        arg_list.set(3, batch_iter);
        status = parallel_for(ctx, nd_range, kernel_, arg_list);
        if (status != status::success) return status;
    }
    return status;
}

status_t gen9_pooling_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}

status_t gen9_pooling_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off, attr()->post_ops_);
}

status_t gen9_pooling_bwd_t::execute_backward(const exec_ctx_t &ctx) const {

    status_t status = status::success;
    auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE);

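    // Kernel argument layout mirrors the forward pass: 0 = diff_src,
    // 1 = workspace, 2 = diff_dst.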
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, diff_src);
    arg_list.set(1, ws);
    arg_list.set(2, diff_dst);

    auto nd_range = pd()->conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl