/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/gen9_batch_normalization.hpp"
#include "gpu/ocl/ocl_utils.hpp"

using namespace dnnl::impl::memory_tracking::names;

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

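// Illustration of the heuristic below (hypothetical shapes): for an NHWC
// tensor with mb = 64, id = 1, ih = iw = 7 and ic = 64 on XeHPC,
// sp = 64 * 1 * 7 * 7 = 3136 and sp / ic = 49 > 40 with ic % 16 == 0,
// so the fused atomics-based reduction is selected; for ic = 512 the ratio
// drops to 6 and the regular multi-kernel reduction is used instead.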
bool use_fused_atomics_reduction(bnorm_conf_t &conf, engine_t *engine) {
    // Currently the fused atomics reduction targets PVC only.
    // The heuristics were selected experimentally, based on PVC perf data.
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    auto gpu_arch = compute_engine->device_info()->gpu_arch();
    const size_t sp = conf.mb * conf.id * conf.ih * conf.iw;
    return gpu_arch >= compute::gpu_arch_t::xe_hpc && conf.ic % 16 == 0
            && sp / conf.ic > 40;
}

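// Illustration for get_ss_utilization() (hypothetical values): with
// gws = {16, 112, 1}, lws = {16, 8, 1} and max_ss = 32, the kernel needs
// div_up(1792, 128) = 14 subslices, i.e. a utilization of 14 / 32 = 0.4375.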
inline float get_ss_utilization(int max_ss, const size_t *gws, size_t *lws) {
    const size_t gws_size = gws[0] * gws[1] * gws[2];
    const size_t lws_size = lws[0] * lws[1] * lws[2];
    const size_t used_ss = utils::div_up(gws_size, lws_size);
    return (float)used_ss / max_ss;
}

// Local work-group size adjustment for the statistics calculation kernel.
void adjust_lws_calc_kernel(bnorm_conf_t &conf, engine_t *engine) {
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    auto eu_count = compute_engine->device_info()->eu_count();
    auto max_lws = compute_engine->device_info()->max_wg_size();
    auto eus_per_ss = compute_engine->device_info()->max_eus_per_wg();
    const int max_ss = utils::div_up(eu_count, eus_per_ss);

    auto generated_nd = conf.dispatch_calc_stat.nd_range();
    const size_t *base_gws = generated_nd.global_range();
    const size_t *base_lws = generated_nd.local_range();

    size_t tuned_lws[3];
    tuned_lws[0] = 16; // Assuming IC is dim 0
    tuned_lws[1] = base_lws[1];
    tuned_lws[2] = base_lws[2];

    // The search is based on subslice utilization, which is calculated as
    // the ratio used_subslices / max_available_subslices.

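    // Illustration (hypothetical values): for base_gws = {16, 112, 1},
    // base_lws = {16, 1, 1}, max_lws = 1024 and max_ss = 32, lws[1] = 1 gives
    // a utilization of 112 / 32 = 3.5, which exceeds the limit of 2, while
    // lws[1] = 2 gives 56 / 32 = 1.75 and is selected as the best candidate.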
    size_t best_lws1 = 1, curr_lws1 = 1;
    float best_ss_utilization = 0.0f, curr_ss_utilization;
    const int ss_util_limit = 2; // experimentally selected

    while (tuned_lws[0] * curr_lws1 * tuned_lws[2] <= (size_t)max_lws
            && curr_lws1 <= base_gws[1]) {
        if (base_gws[1] % curr_lws1) {
            curr_lws1++;
            continue;
        }
        tuned_lws[1] = curr_lws1;
        curr_ss_utilization = get_ss_utilization(max_ss, base_gws, tuned_lws);
        if (curr_ss_utilization > best_ss_utilization
                && curr_ss_utilization < (float)ss_util_limit) {
            best_ss_utilization = curr_ss_utilization;
            best_lws1 = curr_lws1;
        }
        curr_lws1++;
    }
    tuned_lws[1] = best_lws1;

    conf.dispatch_calc_stat.set_lws(tuned_lws);
}

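// Illustration for get_nhwc_ic_block() (hypothetical values): ic = 256 gives
// nblocks = 2 and 256 / 2 = 128 (a multiple of 16), so 128 is returned;
// ic = 160 gives nblocks = 1 < 2, so the full ic = 160 is kept as one block.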
int get_nhwc_ic_block(int ic, int max_vect_size = 8) {
    const int nblocks = ic / (max_vect_size * 16);
    return nblocks < 2 || (ic / nblocks) % 16 ? ic : ic / nblocks;
}
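// Illustration for get_nhwc_vect_size() (hypothetical values): ic = 128 with
// simd = 16 returns 8 (128 / (8 * 16) == 1), while ic = 32 returns 2, since
// the vector size is halved until vect_size * simd fits into ic.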
int get_nhwc_vect_size(int ic, int simd = 16) {
    int vect_size = 8;
    while (true) {
        if (ic / (vect_size * simd)) return vect_size;
        vect_size /= 2;
    }
    return 1;
}

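// Illustration for get_nhwc_sp_block_size() (hypothetical values): with
// sp = 3584, ic_dim = 32 (2 subgroups), eu_count = 512 and threads_per_eu = 8,
// the multiplier nthr_mul = 7 yields a block size of 2, generating exactly
// 3584 threads (full EU occupancy) at 87.5% thread-wave efficiency, so 2 is
// returned as the peak-EU candidate.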
int get_nhwc_sp_block_size(
        int sp, int ic_dim, int eu_count, int threads_per_eu, int simd = 16) {

    float efficiency_thr = 0.0f;
    float efficiency_peak_eu_thr = 0.0f;
    int block_size_thr = 1;
    int block_size_peak_eu_thr = 1;
    int curr_block_size = sp;
    int nthr_mul = 1;
    const int ic_nsg = ic_dim / simd; // number of subgroups along the IC dim

    // The search is based on thread-wave efficiency, giving higher priority
    // to cases with peak EU utilization.
    while (nthr_mul <= 32) {
        const int nthr = nthr_mul * eu_count;
        curr_block_size = utils::div_up(sp * ic_nsg, nthr);
        const int nblock = utils::div_up(sp, curr_block_size);
        const int nthr_gen = nblock * ic_nsg;

        const float curr_efficiency_eus
                = (float)nthr_gen / utils::rnd_up(nthr_gen, eu_count);
        const float curr_efficiency_thr = (float)nthr_gen
                / utils::rnd_up(nthr_gen, eu_count * threads_per_eu);

        if (curr_efficiency_thr > efficiency_thr) {
            efficiency_thr = curr_efficiency_thr;
            block_size_thr = curr_block_size;
        }
        if (curr_efficiency_eus == 1
                && curr_efficiency_thr > efficiency_peak_eu_thr) {
            efficiency_peak_eu_thr = curr_efficiency_thr;
            block_size_peak_eu_thr = curr_block_size;
        }
        nthr_mul++;
    }
    if (efficiency_peak_eu_thr > 0.0f) return block_size_peak_eu_thr;
    return block_size_thr;
}

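// Illustration for get_block_size() (hypothetical values): for a forward pass
// with hw_threads = 512, nn = 1, ic = 64, work_size = 3136 and simd = 16, the
// first iteration produces 128 blocks of minimal size 25, rounded up to the
// 16-element alignment, i.e. a block size of 32; since 32 <= 256, the loop
// stops and 32 is returned.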
int get_block_size(bool is_backward, int hw_threads, int nn, int ic,
        int work_size, int simd = 16) {
    int block_size = 256;
    float thread_efficiency = 0;
    int hw_thread_mult = hw_threads;
    const int align_size = is_backward ? 8 : 16;
    while (true) {
        const int nof_blocks
                = nstl::max(utils::rnd_dn(hw_thread_mult * simd, ic) / ic, 1);
        const int min_block_size
                = utils::rnd_up(work_size, nof_blocks) / nof_blocks;
        const int curr_block_size = utils::rnd_up(min_block_size, align_size);
        const int nof_blocks_generated
                = utils::rnd_up(work_size, curr_block_size) / curr_block_size;
        const int threads_generated = nof_blocks_generated * ic / simd;
        const float curr_thread_efficiency = float(threads_generated * nn)
                / float(utils::rnd_up(threads_generated * nn, hw_threads));
        if (curr_thread_efficiency > thread_efficiency) {
            thread_efficiency = curr_thread_efficiency;
            block_size = curr_block_size;
        }
        if (curr_thread_efficiency == 1.0 || curr_block_size <= 256) { break; }
        hw_thread_mult += hw_threads;
    }
    return block_size;
}

static status_t init_conf_common(bnorm_conf_t &conf, offsets_t &off,
        const batch_normalization_pd_t *pd, engine_t *engine) {
    using namespace dnnl::impl::format_tag;

    const batch_normalization_desc_t &bd = *pd->desc();
    const memory_desc_wrapper data_mdw(
            pd->is_fwd() ? pd->src_md() : pd->diff_src_md());
    const int ndims = data_mdw.ndims();

    conf.data_type = data_mdw.data_type();

    conf.ndims = ndims;
    conf.mb = data_mdw.dims()[0];

    conf.ic = data_mdw.dims()[1];
    conf.id = (ndims == 5) ? data_mdw.dims()[2] : 1;
    conf.ih = (ndims == 3) ? 1 : data_mdw.dims()[ndims - 2];
    conf.iw = data_mdw.dims()[ndims - 1];

    conf.is_forward = pd->is_fwd();
    conf.is_backward = !pd->is_fwd();

    conf.use_scale = pd->use_scale();
    conf.use_shift = pd->use_shift();
    conf.save_stats = pd->is_training();
    conf.is_training = pd->is_training();
    conf.fuse_norm_add_relu = pd->fuse_norm_add_relu();
    conf.fuse_norm_relu = pd->fuse_norm_relu() || pd->fuse_norm_add_relu();
    conf.calculate_stats = !pd->stats_is_src();
    conf.with_relu = pd->with_relu_post_op();
    conf.eps = bd.batch_norm_epsilon;
    conf.calculate_diff_stats = !pd->use_global_stats();
    conf.diff_scale = (pd->use_scale() && bd.prop_kind == prop_kind::backward);
    conf.diff_shift = (pd->use_shift() && bd.prop_kind == prop_kind::backward);

    set_offsets(data_mdw, off.src_off);

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    auto gpu_arch = compute_engine->device_info()->gpu_arch();

    conf.mb_block = 1;

    const bool has_padding = !data_mdw.is_dense();
    const bool is_blocked_16c
            = data_mdw.matches_one_of_tag(nCw16c, nChw16c, nCdhw16c);
    const bool is_blocked_16n16c
            = data_mdw.matches_one_of_tag(NCw16n16c, NChw16n16c, NCdhw16n16c);
    const bool is_blocked_32n16c
            = data_mdw.matches_one_of_tag(NCw32n16c, NChw32n16c, NCdhw32n16c);
    const bool is_nhwc
            = conf.ic % 8 == 0 && data_mdw.matches_one_of_tag(nwc, nhwc, ndhwc);

    // Because intel_sub_group_write_uc requires 16-byte alignment,
    // IC-divisible-by-8 tail processing is not applicable with fuse_norm_relu
    // or the char data type.
    if (conf.ic % 8 == 0 && conf.ic % 16
            && (conf.fuse_norm_relu || conf.data_type == data_type::s8))
        return status::unimplemented;
    // The performance benefit of IC tail processing is not evident on
    // archs older than xe_hpc.
    if (conf.ic % 8 == 0 && conf.ic % 16
            && gpu_arch < compute::gpu_arch_t::xe_hpc)
        return status::unimplemented;

    conf.use_stats_one_pass = experimental::use_bnorm_stats_one_pass();

    // IC tail processing is not implemented yet for the one-pass algorithm.
    // TODO: implement it; the possible perf boost could be ~2x.
    if (conf.ic % 8 == 0 && conf.ic % 16 && conf.use_stats_one_pass)
        conf.use_stats_one_pass = false;

    conf.use_nhwc = is_nhwc;

    if (has_padding
            || !(is_blocked_16c || is_blocked_16n16c || is_blocked_32n16c
                    || is_nhwc))
        return status::unimplemented;

    // IC tail processing is not implemented yet for NHWC-optimized kernels.
    // TODO: implement it.
    // The use of NHWC-optimized kernels is limited to XeHPG+
    // for performance reasons.
    conf.nhwc_optimized = conf.ic % 16 == 0
            && data_mdw.matches_one_of_tag(nwc, nhwc, ndhwc)
            && gpu_arch >= compute::gpu_arch_t::xe_hpg;

    conf.use_fused_atomics_reduction
            = use_fused_atomics_reduction(conf, engine);

    if (conf.nhwc_optimized) {
        conf.ic_block = get_nhwc_ic_block(utils::rnd_up(conf.ic, 16));
    } else {
        conf.ic_block = 16;
    }

    conf.mb_block = is_blocked_32n16c ? 32 : is_blocked_16n16c ? 16 : 1;

    if (is_nhwc) {
        // reshape to xc
        conf.nn = 1;
        conf.sp = conf.mb * conf.id * conf.ih * conf.iw;
    } else {
        // reshape to nCx16c
        conf.nn = conf.mb / conf.mb_block;
        conf.sp = conf.id * conf.ih * conf.iw * conf.mb_block;
    }
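    // Illustration of the reshape (hypothetical shapes): an NHWC tensor with
    // mb = 32 and 56x56 spatial gives nn = 1, sp = 32 * 56 * 56 = 100352,
    // whereas the same shape in NChw16n16c gives mb_block = 16, nn = 2 and
    // sp = 56 * 56 * 16 = 50176.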

    // The IC == 8 case requires the spatial dim to be even because one
    // block read/write operation handles 2 spatial rows at once.
    if (is_nhwc && conf.ic == 8 && conf.sp % 2) return status::unimplemented;

    conf.calc_stat_ic = conf.nhwc_optimized
            ? utils::div_up(conf.ic, conf.ic_block) * 16
            : utils::rnd_up(conf.ic, 16);

    auto eu_count = compute_engine->device_info()->eu_count();
    auto threads_per_eu
            = compute::device_info_t::threads_per_eu(gpu_arch, false);
    const int max_sp_block_size = get_block_size(conf.is_backward, eu_count,
            conf.nn, utils::rnd_up(conf.ic, 16), conf.sp);

    if (conf.nn == 1)
        conf.stat_sp_block = conf.nhwc_optimized
                ? get_nhwc_sp_block_size(
                        conf.sp, conf.calc_stat_ic, eu_count, threads_per_eu)
                : max_sp_block_size;
    else
        conf.stat_sp_block
                = nstl::min(utils::rnd_up(conf.sp, 16), max_sp_block_size);

    conf.stat_sp_nblocks
            = utils::rnd_up(conf.sp, conf.stat_sp_block) / conf.stat_sp_block;
    conf.stat_sp_tail
            = utils::rnd_dn(conf.sp, conf.stat_sp_block) / conf.stat_sp_block;
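    // Illustration (hypothetical values): for sp = 1000 and stat_sp_block
    // = 256, stat_sp_nblocks = rnd_up(1000, 256) / 256 = 4 while
    // stat_sp_tail = rnd_dn(1000, 256) / 256 = 3, i.e. the last (fourth)
    // block is a partial tail block.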

    conf.reduce_stat_nblocks = conf.nn * conf.stat_sp_nblocks;

    conf.vect_size
            = conf.nhwc_optimized ? get_nhwc_vect_size(conf.ic_block) : 8;

    conf.dispatch_calc_stat = compute_engine->create_dispatch();
    conf.dispatch_calc_stat.define_dim("STAT_MB", 0, conf.nn);
    conf.dispatch_calc_stat.define_dim("STAT_SP", 1, conf.stat_sp_nblocks);
    conf.dispatch_calc_stat.define_dim_with_nesting_level(
            "STAT_IC", 1, conf.calc_stat_ic);
    CHECK(conf.dispatch_calc_stat.vectorize_dim("STAT_IC", 16));
    conf.dispatch_calc_stat.set_kernel_attr_suffix("CALC");
    conf.dispatch_calc_stat.generate();
    if (conf.use_fused_atomics_reduction)
        adjust_lws_calc_kernel(conf, compute_engine);

    conf.dispatch_reduce_stat = compute_engine->create_dispatch();
    int reduce_sub_group_count = 1;
    while (conf.reduce_stat_nblocks % (2 * reduce_sub_group_count) == 0
            && 2 * reduce_sub_group_count * 16 <= 256) {
        reduce_sub_group_count = reduce_sub_group_count * 2;
    }
    conf.stat_ic = reduce_sub_group_count * 16;
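    // Illustration (hypothetical values): reduce_stat_nblocks = 48 lets the
    // subgroup count double from 1 up to 16 (48 is divisible by 2, 4, 8 and
    // 16, and 16 * 16 = 256 work-items stay within the cap), so
    // conf.stat_ic = 256.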
    conf.dispatch_reduce_stat.define_dim("REDUCE_STAT_IC", 0, conf.stat_ic);
    conf.dispatch_reduce_stat.define_dim(
            "REDUCE_IC_GROUP", 1, utils::rnd_up(conf.ic, 16) / 16);
    CHECK(conf.dispatch_reduce_stat.vectorize_dim("REDUCE_STAT_IC", 16));
    conf.dispatch_reduce_stat.set_kernel_attr_suffix("REDUCE");
    conf.dispatch_reduce_stat.generate();

    const int sp_pad = utils::rnd_up(conf.sp, conf.vect_size);
    conf.sp_tail = utils::rnd_dn(conf.sp, conf.vect_size);

    conf.dispatch = compute_engine->create_dispatch(data_mdw.md_);
    conf.dispatch.define_dim("MB", 0, conf.nn);
    conf.dispatch.define_dim("SP", 1,
            conf.nhwc_optimized ? conf.stat_sp_nblocks
                                : sp_pad / conf.vect_size);
    conf.dispatch.define_dim("IC", 2, conf.calc_stat_ic);

    CHECK(conf.dispatch.vectorize_dim("IC", 16));
    conf.dispatch.generate();

    conf.dispatch_reduce_aux = compute_engine->create_dispatch(data_mdw.md_);
    conf.dispatch_reduce_aux.define_dim("IC_AUX", 0, conf.ic);
    conf.dispatch_reduce_aux.set_kernel_attr_suffix("AUX");
    conf.dispatch_reduce_aux.generate();

    return status::success;
}

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const bnorm_conf_t &conf, const offsets_t &off) {
    kernel_ctx.set_data_type(conf.data_type);

    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("MB", conf.mb);
    kernel_ctx.define_int("IC", conf.ic);
    kernel_ctx.define_int("IC16", utils::rnd_up(conf.ic, 16));
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
    kernel_ctx.define_int("IC_BLOCK", conf.ic_block);

    kernel_ctx.define_int("USE_NHWC", conf.use_nhwc);
    kernel_ctx.define_int("SP", conf.sp);
    kernel_ctx.define_int("SP_TAIL", conf.sp_tail);
    kernel_ctx.define_int("VECT_SIZE", conf.vect_size);

    kernel_ctx.define_int("STAT_SP_BLOCK", conf.stat_sp_block);
    kernel_ctx.define_int("STAT_SP_NBLOCKS", conf.stat_sp_nblocks);
    kernel_ctx.define_int("STAT_SP_TAIL", conf.stat_sp_tail);
    kernel_ctx.define_int("REDUCE_STAT_NBLOCKS", conf.reduce_stat_nblocks);

    if (conf.is_forward)
        kernel_ctx.define_int("IS_FWD", 1);
    else if (conf.is_backward)
        kernel_ctx.define_int("IS_BWD", 1);

    kernel_ctx.define_int("WITH_RELU", conf.with_relu);
    kernel_ctx.define_int("SAVE_STATS", conf.save_stats);
    kernel_ctx.define_int("IS_TRAINING", conf.is_training);
    kernel_ctx.define_int("FUSE_BN_RELU", conf.fuse_norm_relu);
    kernel_ctx.define_int("FUSE_BN_ADD_RELU", conf.fuse_norm_add_relu);
    kernel_ctx.define_int("CALCULATE_STATS", conf.calculate_stats);
    kernel_ctx.define_int("USE_SCALE", conf.use_scale);
    kernel_ctx.define_int("USE_SHIFT", conf.use_shift);
    kernel_ctx.define_int("CALCULATE_DIFF_STATS", conf.calculate_diff_stats);
    kernel_ctx.define_int("DIFF_SCALE", conf.diff_scale);
    kernel_ctx.define_int("DIFF_SHIFT", conf.diff_shift);
    kernel_ctx.define_int("REDUCE_IC_SUB_GROUPS", conf.stat_ic / 16);
    kernel_ctx.define_int("USE_STATS_ONE_PASS", conf.use_stats_one_pass);
    kernel_ctx.define_int("NHWC_OPTIMIZED", conf.nhwc_optimized);
    kernel_ctx.define_int(
            "FUSED_ATOMICS_REDUCTION", conf.use_fused_atomics_reduction);

    kernel_ctx.add_option("-cl-std=CL2.0");
    if (conf.data_type == data_type::s8)
        kernel_ctx.add_option("-Dcl_intel_subgroups_char");

    def_offsets(off.src_off, kernel_ctx, "SRC", conf.ndims);

    def_dispatch(kernel_ctx, conf.dispatch_calc_stat);
    def_dispatch(kernel_ctx, conf.dispatch_reduce_stat);
    def_dispatch(kernel_ctx, conf.dispatch_reduce_aux);
    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}

status_t gen9_batch_normalization_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}

status_t gen9_batch_normalization_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off);
}

void gen9_batch_normalization_fwd_t::pd_t::init_scratchpad() {
    if (conf.calculate_stats) {
        size_t size_coeff = sizeof(double) / sizeof(float);
        size_t size = 2 * size_coeff * conf.reduce_stat_nblocks
                * utils::rnd_up(conf.ic, 16);

        auto scratchpad = scratchpad_registry().registrar();
        scratchpad.book(key_bnorm_reduction, size,
                types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
        if (!conf.save_stats) {
            scratchpad.book(key_bnorm_tmp_mean, conf.ic,
                    types::data_type_size(data_type::f32),
                    OCL_BUFFER_ALIGNMENT);
            scratchpad.book(key_bnorm_tmp_var, conf.ic,
                    types::data_type_size(data_type::f32),
                    OCL_BUFFER_ALIGNMENT);
        }
    }
}

status_t gen9_batch_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &src_add = CTX_IN_STORAGE(DNNL_ARG_SRC_1);

    auto &mean_ = pd()->stats_is_src()
            ? CTX_IN_STORAGE(DNNL_ARG_MEAN)
            : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_MEAN, status);
    CHECK(status);

    auto &variance_ = pd()->stats_is_src()
            ? CTX_IN_STORAGE(DNNL_ARG_VARIANCE)
            : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_VARIANCE, status);
    CHECK(status);

    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &shift = CTX_IN_STORAGE(DNNL_ARG_SHIFT);

    auto &dst = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status);
    CHECK(status);
    auto &ws = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_WORKSPACE, status);
    CHECK(status);

    std::unique_ptr<memory_storage_t> temp_reduce;
    std::unique_ptr<memory_storage_t> tmp_mean;
    std::unique_ptr<memory_storage_t> tmp_variance;
    if (conf.calculate_stats) {
        temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
                key_bnorm_reduction);

        if (!conf.save_stats) {
            tmp_mean = ctx.get_scratchpad_grantor().get_memory_storage(
                    key_bnorm_tmp_mean);
            tmp_variance = ctx.get_scratchpad_grantor().get_memory_storage(
                    key_bnorm_tmp_var);
        }
    }

    auto &mean = (conf.calculate_stats && !conf.save_stats) ? *tmp_mean : mean_;
    auto &variance = (conf.calculate_stats && !conf.save_stats) ? *tmp_variance
                                                                : variance_;

    if (conf.calculate_stats && conf.use_fused_atomics_reduction) {
        // Atomics-based reduction requires zeroing mean and variance
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, mean);
        arg_list.set(1, variance);

        auto nd_range = conf.dispatch_reduce_aux.nd_range();
        status = parallel_for(ctx, nd_range, reduce_init_kernel_, arg_list);
        if (status != status::success) return status;
    }

    if (conf.calculate_stats && !conf.use_stats_one_pass) {
        compute::kernel_arg_list_t calc_mean_arg_list;
        calc_mean_arg_list.set(0, src);
        calc_mean_arg_list.set(1, *temp_reduce);
        calc_mean_arg_list.set(2, mean);

        auto nd_range_calc_mean = conf.dispatch_calc_stat.nd_range();

        status = parallel_for(ctx, nd_range_calc_mean, calculate_mean_kernel_,
                calc_mean_arg_list);
        if (status != status::success) return status;

        if (conf.use_fused_atomics_reduction) {
            compute::kernel_arg_list_t arg_list;
            arg_list.set(0, mean);
            auto nd_range = conf.dispatch_reduce_aux.nd_range();
            status = parallel_for(
                    ctx, nd_range, reduce_final_kernel_, arg_list);
            if (status != status::success) return status;
        } else {
            compute::kernel_arg_list_t reduce_mean_arg_list;
            reduce_mean_arg_list.set(0, *temp_reduce);
            reduce_mean_arg_list.set(1, mean);

            auto nd_range_reduce_mean = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_mean,
                    reduce_mean_kernel_, reduce_mean_arg_list);
            if (status != status::success) return status;
        }

        compute::kernel_arg_list_t calc_var_arg_list;
        calc_var_arg_list.set(0, src);
        calc_var_arg_list.set(1, mean);
        calc_var_arg_list.set(2, *temp_reduce);
        calc_var_arg_list.set(3, variance);

        auto nd_range_calc_var = conf.dispatch_calc_stat.nd_range();

        status = parallel_for(ctx, nd_range_calc_var,
                calculate_variance_kernel_, calc_var_arg_list);
        if (status != status::success) return status;

        if (conf.use_fused_atomics_reduction) {
            compute::kernel_arg_list_t arg_list;
            arg_list.set(0, variance);
            auto nd_range = conf.dispatch_reduce_aux.nd_range();
            status = parallel_for(
                    ctx, nd_range, reduce_final_kernel_, arg_list);
            if (status != status::success) return status;
        } else {
            compute::kernel_arg_list_t reduce_var_arg_list;
            reduce_var_arg_list.set(0, *temp_reduce);
            reduce_var_arg_list.set(1, variance);

            auto nd_range_reduce_var = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_var,
                    reduce_variance_kernel_, reduce_var_arg_list);
            if (status != status::success) return status;
        }
    }
    if (conf.calculate_stats && conf.use_stats_one_pass) {
        compute::kernel_arg_list_t calc_mean_var_arg_list;
        calc_mean_var_arg_list.set(0, src);
        calc_mean_var_arg_list.set(1, *temp_reduce);
        calc_mean_var_arg_list.set(2, mean);
        calc_mean_var_arg_list.set(3, variance);

        auto nd_range_calc_mean = conf.dispatch_calc_stat.nd_range();

        status = parallel_for(ctx, nd_range_calc_mean,
                calculate_mean_var_kernel_, calc_mean_var_arg_list);
        if (status != status::success) return status;

        if (conf.use_fused_atomics_reduction) {
            compute::kernel_arg_list_t reduce_final_arg_list;
            reduce_final_arg_list.set(0, mean);
            reduce_final_arg_list.set(1, variance);
            auto nd_range_reduce_final = conf.dispatch_reduce_aux.nd_range();

            status = parallel_for(ctx, nd_range_reduce_final,
                    reduce_final_kernel_, reduce_final_arg_list);
            if (status != status::success) return status;
        } else {
            compute::kernel_arg_list_t reduce_mean_var_arg_list;
            reduce_mean_var_arg_list.set(0, *temp_reduce);
            reduce_mean_var_arg_list.set(1, mean);
            reduce_mean_var_arg_list.set(2, variance);

            auto nd_range_reduce_mean = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_mean,
                    reduce_mean_var_kernel_, reduce_mean_var_arg_list);
            if (status != status::success) return status;
        }
    }
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, dst);
    arg_list.set(4, scale);
    arg_list.set(5, shift);
    arg_list.set(6, ws);
    arg_list.set(7, conf.eps);
    arg_list.set(8, src_add);

    auto nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);
    return status;
}

status_t gen9_batch_normalization_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}

status_t gen9_batch_normalization_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off);
}

void gen9_batch_normalization_bwd_t::pd_t::init_scratchpad() {
    size_t size
            = 2 * utils::rnd_up(conf.ic, 16) * (1 + conf.reduce_stat_nblocks);
    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(key_bnorm_reduction, size,
            types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
}

status_t gen9_batch_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;

    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &mean = CTX_IN_STORAGE(DNNL_ARG_MEAN);
    auto &variance = CTX_IN_STORAGE(DNNL_ARG_VARIANCE);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE);

    auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
    CHECK(status);
    auto &diff_src_add = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC_1, status);
    CHECK(status);

    auto &diff_scale_ = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SCALE, status);
    CHECK(status);
    auto &diff_shift_ = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SHIFT, status);
    CHECK(status);

    std::unique_ptr<memory_storage_t> temp_reduce;
    temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
            key_bnorm_reduction);

    auto &diff_scale = !conf.diff_scale ? *temp_reduce : diff_scale_;
    auto &diff_shift = !conf.diff_shift ? *temp_reduce : diff_shift_;

    if (conf.use_fused_atomics_reduction) {
        compute::kernel_arg_list_t reduce_init_arg_list;
        reduce_init_arg_list.set(0, diff_scale);
        reduce_init_arg_list.set(1, diff_shift);

        auto nd_range_reduce_init = conf.dispatch_reduce_aux.nd_range();
        status = parallel_for(ctx, nd_range_reduce_init, reduce_init_kernel_,
                reduce_init_arg_list);
        if (status != status::success) return status;
    }

    compute::kernel_arg_list_t calc_stats_arg_list;
    calc_stats_arg_list.set(0, src);
    calc_stats_arg_list.set(1, mean);
    calc_stats_arg_list.set(2, diff_dst);
    calc_stats_arg_list.set(3, ws);
    calc_stats_arg_list.set(4, *temp_reduce);
    calc_stats_arg_list.set(5, diff_scale);
    calc_stats_arg_list.set(6, diff_shift);

    auto nd_range = conf.dispatch_calc_stat.nd_range();
    status = parallel_for(
            ctx, nd_range, calculate_stats_kernel_, calc_stats_arg_list);
    if (status != status::success) return status;

    if (conf.use_fused_atomics_reduction) {
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, diff_scale);
        arg_list.set(1, variance);
        arg_list.set(2, conf.eps);
        auto nd_range = conf.dispatch_reduce_aux.nd_range();
        status = parallel_for(ctx, nd_range, reduce_final_kernel_, arg_list);
        if (status != status::success) return status;
    } else {
        compute::kernel_arg_list_t reduce_stats_arg_list;
        reduce_stats_arg_list.set(0, *temp_reduce);
        reduce_stats_arg_list.set(1, diff_scale);
        reduce_stats_arg_list.set(2, diff_shift);
        reduce_stats_arg_list.set(3, variance);
        reduce_stats_arg_list.set(4, conf.eps);

        auto nd_range_reduce_stat = conf.dispatch_reduce_stat.nd_range();
        status = parallel_for(ctx, nd_range_reduce_stat, reduce_stats_kernel_,
                reduce_stats_arg_list);
        if (status != status::success) return status;
    }

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, diff_dst);
    arg_list.set(4, scale);
    arg_list.set(5, ws);
    arg_list.set(6, diff_src);
    arg_list.set(7, diff_scale);
    arg_list.set(8, diff_shift);
    arg_list.set(9, conf.eps);
    arg_list.set(10, diff_src_add);

    nd_range = conf.dispatch.nd_range();
    status = parallel_for(ctx, nd_range, bwd_kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl