/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ref_batch_normalization.hpp"

#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/math_utils.hpp"
#include "common/scratchpad.hpp"
#include "common/type_helpers.hpp"

using namespace dnnl::impl::memory_tracking::names;

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

static status_t init_conf_common(bnorm_conf_t &conf, offsets_t &off,
        const batch_normalization_pd_t *pd, engine_t *engine) {
    using namespace dnnl::impl::format_tag;

    const batch_normalization_desc_t &bd = *pd->desc();
    const memory_desc_wrapper data_mdw(
            pd->is_fwd() ? pd->src_md() : pd->diff_src_md());
    const int ndims = data_mdw.ndims();

    conf = utils::zero<decltype(conf)>();
    conf.data_type = data_mdw.data_type();

    conf.ndims = ndims;
    conf.mb = data_mdw.dims()[0];

    conf.ic = data_mdw.dims()[1];
    conf.id = (ndims >= 5) ? data_mdw.dims()[ndims - 3] : 1;
    conf.ih = (ndims >= 4) ? data_mdw.dims()[ndims - 2] : 1;
    conf.iw = (ndims >= 3) ? data_mdw.dims()[ndims - 1] : 1;

    conf.is_forward = pd->is_fwd();
    conf.is_backward = !pd->is_fwd();

    conf.use_scale = pd->use_scale();
    conf.use_shift = pd->use_shift();
    conf.save_stats = pd->is_training();
    conf.is_training = pd->is_training();
    conf.fuse_norm_relu = pd->fuse_norm_relu() || pd->fuse_norm_add_relu();
    conf.fuse_norm_add_relu = pd->fuse_norm_add_relu();
    conf.calculate_stats = !pd->stats_is_src();
    conf.with_relu = pd->with_relu_post_op();
    conf.eps = bd.batch_norm_epsilon;
    conf.calculate_diff_stats = !pd->use_global_stats();
    conf.diff_scale = (pd->use_scale() && bd.prop_kind == prop_kind::backward);
    conf.diff_shift = (pd->use_shift() && bd.prop_kind == prop_kind::backward);

    set_offsets(data_mdw, off.src_off);

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

    conf.use_16mb_unroll = false;
    conf.use_nhwc = false;
    conf.mb_block = 1;
    conf.ic_block = 1;

    const bool has_padding = !data_mdw.is_dense();

    if (!has_padding && conf.is_backward
            && data_mdw.matches_one_of_tag(nCw16c, nChw16c, nCdhw16c, NCw16n16c,
                    NChw16n16c, NCdhw16n16c)) {
        conf.mb_block = data_mdw.matches_one_of_tag(
                                NCw16n16c, NChw16n16c, NCdhw16n16c)
                ? 16
                : 1;
        conf.ic_block = 16;
        conf.use_16mb_unroll = true;

        const int max_stat_nblocks = 256;
        int stat_mb_nblocks = conf.mb / conf.mb_block;
        int stat_sp_nblocks = utils::max_div(conf.id * conf.ih * conf.iw,
                nstl::max(1, max_stat_nblocks / stat_mb_nblocks));
        assert(stat_mb_nblocks * stat_sp_nblocks <= max_stat_nblocks);

        int stat_sp_block = conf.id * conf.ih * conf.iw / stat_sp_nblocks;

        conf.reduce_stat_nblocks = stat_mb_nblocks * stat_sp_nblocks;
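        // Rough illustration of the blocking above (hypothetical shape, not
        // from the sources): with mb = 32 and mb_block = 16, stat_mb_nblocks
        // is 2, so the spatial statistics may use at most 256 / 2 = 128
        // chunks; for a 14x14 spatial, max_div picks the largest divisor of
        // 196 not exceeding 128 (here 98), giving 2 * 98 = 196 partial
        // reductions and a spatial block of 196 / 98 = 2 elements.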

        conf.dispatch_calc_stat = compute_engine->create_dispatch();
        conf.dispatch_calc_stat.define_dim_with_nesting_level(
                "STAT_SP", 2, conf.id * conf.ih * conf.iw, stat_sp_block);
        conf.dispatch_calc_stat.define_dim_with_nesting_level(
                "STAT_IC", 1, conf.ic);
        conf.dispatch_calc_stat.define_dim_with_nesting_level(
                "STAT_MB", 0, conf.mb, conf.mb_block);
        CHECK(conf.dispatch_calc_stat.vectorize_dim("STAT_IC", 16));
        conf.dispatch_calc_stat.set_kernel_attr_suffix("CALC");
        conf.dispatch_calc_stat.generate();

        conf.dispatch_reduce_stat = compute_engine->create_dispatch();
        conf.dispatch_reduce_stat.define_dim("REDUCE_STAT_IC", conf.ic);
        conf.dispatch_reduce_stat.set_kernel_attr_suffix("REDUCE");
        conf.dispatch_reduce_stat.generate();

        conf.dispatch = compute_engine->create_dispatch(data_mdw.md_);
        conf.dispatch.define_dim("MB", 0, conf.mb, conf.mb_block);
        conf.dispatch.define_dim("IC", 1, conf.ic);
        conf.dispatch.define_dim("ID", nstl::max(1, ndims - 3), conf.id);
        conf.dispatch.define_dim("IH", nstl::max(1, ndims - 2), conf.ih);
        conf.dispatch.define_dim("IW", nstl::max(1, ndims - 1), conf.iw);
        CHECK(conf.dispatch.vectorize_dim("IC", 16));
        conf.dispatch.generate();
    } else {
        // Reference
        conf.use_16mb_unroll = false;
        conf.dispatch = compute_engine->create_dispatch(data_mdw.md_);
        conf.dispatch.define_dim("MB", 0, conf.mb);
        conf.dispatch.define_dim("IC", 1, conf.ic);
        conf.dispatch.define_dim("ID", nstl::max(1, ndims - 3), conf.id);
        conf.dispatch.define_dim("IH", nstl::max(1, ndims - 2), conf.ih);
        conf.dispatch.define_dim("IW", nstl::max(1, ndims - 1), conf.iw);

        conf.dispatch.generate();
        if (conf.calculate_stats || conf.is_backward) {

            conf.dispatch_calc_stat
                    = compute_engine->create_dispatch(data_mdw.md_);
            int calc_dims[5];
            auto &dims = data_mdw.dims();
            calc_dims[0] = dims[0];
            calc_dims[1] = dims[1];
            calc_dims[2] = (ndims < 5) ? 1 : dims[ndims - 3];
            calc_dims[3] = (ndims < 4) ? 1 : dims[ndims - 2];
            calc_dims[4] = (ndims < 3) ? 1 : dims[ndims - 1];
            int reduce_dim_idx = 0;
            for (int i = 2; i < 5; i++) {
                if (calc_dims[i] > calc_dims[reduce_dim_idx]) {
                    reduce_dim_idx = i;
                }
            }
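            // The loop above starts from MB (index 0) and scans the spatial
            // dims, so the reduce dimension is the largest of
            // {MB, ID, IH, IW}. IC (index 1) is never chosen because batch
            // normalization keeps per-channel statistics.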
            conf.reduce_dim = calc_dims[reduce_dim_idx];
            conf.reduce_dim_idx = reduce_dim_idx;
            const std::string dim_names[5]
                    = {"STAT_MB", "STAT_IC", "STAT_ID", "STAT_IH", "STAT_IW"};
            const std::string &reduce_dim_name = dim_names[reduce_dim_idx];

            conf.vectorize_calc_stats = false;
            conf.vect_size = 1;
            conf.sub_group_size = 1;
            int calc_dims_blocks[5] = {1, 1, 1, 1, 1};

            // Translate reduce_dim_idx from an index into calc_dims to an
            // index into the dims array.
            const int base_reduce_dim_idx
                    = reduce_dim_idx == 0 ? 0 : reduce_dim_idx - (5 - ndims);
            const int reduce_dim_stride
                    = data_mdw.blocking_desc().strides[base_reduce_dim_idx];
            if (conf.is_forward && conf.reduce_dim % 16 == 0
                    && reduce_dim_stride == 1) {
                // Calculations over the reduce dimension are split between
                // the work items of a single subgroup.
                // Each item reads vector_size elements at once.
                conf.vectorize_calc_stats = true;
                conf.sub_group_size = 16;

                int vector_size = 8;
                while (conf.reduce_dim % (conf.sub_group_size * vector_size)
                        != 0) {
                    vector_size /= 2;
                }
                conf.vect_size = vector_size;
                calc_dims_blocks[reduce_dim_idx]
                        = conf.reduce_dim / conf.sub_group_size;
            } else {
                // The whole reduce dimension is handled by a single work item.
                calc_dims[reduce_dim_idx] = 1;
            }
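            // Illustration of the vectorized branch (hypothetical shape, not
            // from the sources): with reduce_dim = 64 and unit stride,
            // sub_group_size is 16 and the loop drops vector_size from 8 to 4
            // because 64 % (16 * 8) != 0 while 64 % (16 * 4) == 0; each of
            // the 16 work items then reduces 64 / 16 = 4 contiguous elements.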

            conf.stat_ic = utils::array_product(calc_dims, 5);
            conf.dispatch_calc_stat.define_dim(
                    dim_names[0], 0, calc_dims[0], calc_dims_blocks[0]);
            conf.dispatch_calc_stat.define_dim(
                    dim_names[1], 1, calc_dims[1], calc_dims_blocks[1]);
            conf.dispatch_calc_stat.define_dim(dim_names[2],
                    nstl::max(1, ndims - 3), calc_dims[2], calc_dims_blocks[2]);
            conf.dispatch_calc_stat.define_dim(dim_names[3],
                    nstl::max(1, ndims - 2), calc_dims[3], calc_dims_blocks[3]);
            conf.dispatch_calc_stat.define_dim(dim_names[4],
                    nstl::max(1, ndims - 1), calc_dims[4], calc_dims_blocks[4]);

            conf.skip_reduce_stat = false;
            if (conf.vectorize_calc_stats) {
                CHECK(conf.dispatch_calc_stat.vectorize_dim(
                        reduce_dim_name, conf.sub_group_size));
                if (conf.stat_ic == conf.reduce_dim * calc_dims[1]) {
                    // If the only dimensions greater than 1 are IC and
                    // reduce_dim, the calc phase of batchnorm performs the
                    // whole reduction and the reduce phase can be skipped.
                    conf.skip_reduce_stat = true;
                }
            }
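            // Example of the skip case (hypothetical shape): mb = 1,
            // ic = 256 and a 1x1x1024 spatial give reduce_dim = 1024 (IW)
            // and stat_ic = 1 * 256 * 1 * 1 * 1024 = reduce_dim * ic, so the
            // CALC kernel already produces fully reduced per-channel
            // statistics and the REDUCE kernel is never launched.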

            conf.dispatch_calc_stat.set_kernel_attr_suffix("CALC");
            conf.dispatch_calc_stat.generate();

            conf.dispatch_reduce_stat = compute_engine->create_dispatch();
            conf.dispatch_reduce_stat.define_dim("REDUCE_STAT_IC", conf.ic);
            conf.dispatch_reduce_stat.set_kernel_attr_suffix("REDUCE");
            conf.dispatch_reduce_stat.generate();
        }
    }

    return status::success;
}

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const bnorm_conf_t &conf, const offsets_t &off) {
    kernel_ctx.set_data_type(conf.data_type);

    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("MB", conf.mb);
    kernel_ctx.define_int("IC", conf.ic);
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("USE_16MB_UNROLL", conf.use_16mb_unroll);
    kernel_ctx.define_int("USE_NHWC", conf.use_nhwc);
    kernel_ctx.define_int("REDUCE_STAT_NBLOCKS", conf.reduce_stat_nblocks);
    kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
    kernel_ctx.define_int("IC_BLOCK", conf.ic_block);

    kernel_ctx.define_int("REDUCE_DIM_IDX", conf.reduce_dim_idx);
    kernel_ctx.define_int("REDUCE_DIM", conf.reduce_dim);

    if (conf.is_forward)
        kernel_ctx.define_int("IS_FWD", 1);
    else if (conf.is_backward)
        kernel_ctx.define_int("IS_BWD", 1);

    kernel_ctx.define_int("WITH_RELU", conf.with_relu);
    kernel_ctx.define_int("SAVE_STATS", conf.save_stats);
    kernel_ctx.define_int("IS_TRAINING", conf.is_training);
    kernel_ctx.define_int("FUSE_BN_RELU", conf.fuse_norm_relu);
    kernel_ctx.define_int("FUSE_BN_ADD_RELU", conf.fuse_norm_add_relu);
    kernel_ctx.define_int("CALCULATE_STATS", conf.calculate_stats);
    kernel_ctx.define_int("USE_SCALE", conf.use_scale);
    kernel_ctx.define_int("USE_SHIFT", conf.use_shift);
    kernel_ctx.define_int("CALCULATE_DIFF_STATS", conf.calculate_diff_stats);
    kernel_ctx.define_int("DIFF_SCALE", conf.diff_scale);
    kernel_ctx.define_int("DIFF_SHIFT", conf.diff_shift);
    kernel_ctx.define_int("VECTORIZE_CALC_STATS", conf.vectorize_calc_stats);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
    kernel_ctx.define_int("VECT_SIZE", conf.vect_size);
    kernel_ctx.define_int("SKIP_REDUCE_STATS", conf.skip_reduce_stat);

    def_offsets(off.src_off, kernel_ctx, "SRC", conf.ndims);

    if (conf.data_type == data_type::s8)
        kernel_ctx.add_option("-Dcl_intel_subgroups_char");

    if (conf.calculate_stats || conf.is_backward) {
        def_dispatch(kernel_ctx, conf.dispatch_calc_stat);
        def_dispatch(kernel_ctx, conf.dispatch_reduce_stat);
    }
    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}

status_t ref_batch_normalization_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}

status_t ref_batch_normalization_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off);
}

void ref_batch_normalization_fwd_t::pd_t::init_scratchpad() {
    if (conf.calculate_stats) {

        size_t size = 2 * conf.stat_ic;
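        // The factor of 2 presumably reserves room for both the partial mean
        // and the partial variance accumulators (one f32 value each per
        // statistic slot) in the same reduction scratchpad.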

        auto scratchpad = scratchpad_registry().registrar();
        scratchpad.book(memory_tracking::names::key_bnorm_reduction, size,
                types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
    }
}

status_t ref_batch_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &src_add = CTX_IN_STORAGE(DNNL_ARG_SRC_1);

    auto &mean_ = pd()->stats_is_src()
            ? CTX_IN_STORAGE(DNNL_ARG_MEAN)
            : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_MEAN, status);
    CHECK(status);

    auto &variance_ = pd()->stats_is_src()
            ? CTX_IN_STORAGE(DNNL_ARG_VARIANCE)
            : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_VARIANCE, status);
    CHECK(status);

    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &shift = CTX_IN_STORAGE(DNNL_ARG_SHIFT);

    auto &dst = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status);
    CHECK(status);
    auto &ws = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_WORKSPACE, status);
    CHECK(status);

    auto *mean_ptr = &mean_;
    auto *variance_ptr = &variance_;

    std::unique_ptr<memory_storage_t> temp_reduce = nullptr;
    if (conf.calculate_stats) {
        if (!conf.skip_reduce_stat || !conf.save_stats) {
            temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
                    key_bnorm_reduction);
        }

        if (!conf.save_stats) {
            mean_ptr = temp_reduce.get();
            variance_ptr = temp_reduce.get();
        }
    }
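    // When the statistics are computed but do not need to be saved as
    // outputs, mean and variance are redirected into the reduction
    // scratchpad so the kernels still have a valid destination to write to.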

    auto &mean = *mean_ptr;
    auto &variance = *variance_ptr;

    if (conf.calculate_stats) {
        if (conf.skip_reduce_stat) {
            compute::kernel_arg_list_t calc_var_arg_list;
            calc_var_arg_list.set(0, src);
            calc_var_arg_list.set(1, mean);
            calc_var_arg_list.set(2, variance);

            auto nd_range_calc_var = conf.dispatch_calc_stat.nd_range();

            status = parallel_for(ctx, nd_range_calc_var,
                    calculate_mean_variance_kernel_, calc_var_arg_list);
            if (status != status::success) return status;
        } else {
            compute::kernel_arg_list_t calc_mean_arg_list;
            calc_mean_arg_list.set(0, src);
            calc_mean_arg_list.set(1, *temp_reduce);

            auto nd_range_calc_mean = conf.dispatch_calc_stat.nd_range();

            status = parallel_for(ctx, nd_range_calc_mean,
                    calculate_mean_kernel_, calc_mean_arg_list);
            if (status != status::success) return status;

            compute::kernel_arg_list_t reduce_mean_arg_list;
            reduce_mean_arg_list.set(0, *temp_reduce);
            reduce_mean_arg_list.set(1, mean);

            auto nd_range_reduce_mean = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_mean,
                    reduce_mean_kernel_, reduce_mean_arg_list);
            if (status != status::success) return status;

            compute::kernel_arg_list_t calc_var_arg_list;
            calc_var_arg_list.set(0, src);
            calc_var_arg_list.set(1, mean);
            calc_var_arg_list.set(2, *temp_reduce);

            auto nd_range_calc_var = conf.dispatch_calc_stat.nd_range();

            status = parallel_for(ctx, nd_range_calc_var,
                    calculate_variance_kernel_, calc_var_arg_list);
            if (status != status::success) return status;

            compute::kernel_arg_list_t reduce_var_arg_list;
            reduce_var_arg_list.set(0, *temp_reduce);
            reduce_var_arg_list.set(1, variance);

            auto nd_range_reduce_var = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_var,
                    reduce_variance_kernel_, reduce_var_arg_list);
            if (status != status::success) return status;
        }
    }
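    // Summary of the statistics pipeline above: either a single fused
    // mean+variance kernel (skip_reduce_stat), or two CALC/REDUCE pairs that
    // first accumulate partial sums into the scratchpad and then reduce them
    // into the per-channel mean and variance.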

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, dst);
    arg_list.set(4, scale);
    arg_list.set(5, shift);
    arg_list.set(6, ws);
    arg_list.set(7, conf.eps);
    arg_list.set(8, src_add);

    auto nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}

status_t ref_batch_normalization_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}

status_t ref_batch_normalization_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off);
}

void ref_batch_normalization_bwd_t::pd_t::init_scratchpad() {
    size_t size;
    if (conf.use_16mb_unroll) {
        size = 2 * conf.reduce_stat_nblocks * conf.ic;
    } else {
        size = 2 * conf.stat_ic;
    }
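    // In both cases the factor of 2 presumably leaves room for two sets of
    // partial reductions (diff_scale and diff_shift accumulators) produced
    // by the CALC kernel below.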

    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(memory_tracking::names::key_bnorm_reduction, size,
            types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
}

status_t ref_batch_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &mean = CTX_IN_STORAGE(DNNL_ARG_MEAN);
    auto &variance = CTX_IN_STORAGE(DNNL_ARG_VARIANCE);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE);

    auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
    CHECK(status);
    auto &diff_src_add = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC_1, status);
    CHECK(status);
    auto &diff_scale_ = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SCALE, status);
    CHECK(status);
    auto &diff_shift_ = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SHIFT, status);
    CHECK(status);

    std::unique_ptr<memory_storage_t> temp_reduce;
    temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
            key_bnorm_reduction);

    compute::kernel_arg_list_t calc_stats_arg_list;
    calc_stats_arg_list.set(0, src);
    calc_stats_arg_list.set(1, mean);
    calc_stats_arg_list.set(2, diff_dst);
    calc_stats_arg_list.set(3, ws);
    calc_stats_arg_list.set(4, *temp_reduce);

    auto nd_range = conf.dispatch_calc_stat.nd_range();

    status = parallel_for(
            ctx, nd_range, calculate_stats_kernel_, calc_stats_arg_list);
    if (status != status::success) return status;

    auto &diff_scale = !conf.diff_scale ? *temp_reduce : diff_scale_;
    auto &diff_shift = !conf.diff_shift ? *temp_reduce : diff_shift_;
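    // If diff_scale or diff_shift is not a requested output, it falls back to
    // the scratchpad buffer so the reduce kernel still has a valid
    // destination to write to.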

    compute::kernel_arg_list_t reduce_stats_arg_list;
    reduce_stats_arg_list.set(0, *temp_reduce);
    reduce_stats_arg_list.set(1, diff_scale);
    reduce_stats_arg_list.set(2, diff_shift);
    reduce_stats_arg_list.set(3, variance);
    reduce_stats_arg_list.set(4, conf.eps);

    auto nd_range_reduce_stat = conf.dispatch_reduce_stat.nd_range();

    status = parallel_for(ctx, nd_range_reduce_stat, reduce_stats_kernel_,
            reduce_stats_arg_list);
    if (status != status::success) return status;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, diff_dst);
    arg_list.set(4, scale);
    arg_list.set(5, ws);
    arg_list.set(6, diff_src);
    arg_list.set(7, diff_scale);
    arg_list.set(8, diff_shift);
    arg_list.set(9, conf.eps);
    arg_list.set(10, diff_src_add);

    nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s