/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/gen9_batch_normalization.hpp"
#include "gpu/ocl/ocl_utils.hpp"

using namespace dnnl::impl::memory_tracking::names;

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
bool use_fused_atomics_reduction(bnorm_conf_t &conf, engine_t *engine) {
    // Currently, the fused atomics-based reduction targets PVC only.
    // The heuristics were selected experimentally, based on PVC perf data.
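    // A worked example with hypothetical shapes: mb = 32, id = 1,
    // ih = iw = 56 gives sp = 32 * 1 * 56 * 56 = 100352; with ic = 256,
    // sp / ic = 392 > 40, so the fused reduction is used on xe_hpc+.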
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    auto gpu_arch = compute_engine->device_info()->gpu_arch();
    const size_t sp = conf.mb * conf.id * conf.ih * conf.iw;
    return gpu_arch >= compute::gpu_arch_t::xe_hpc && conf.ic % 16 == 0
            && sp / conf.ic > 40;
}

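// Subslice utilization of a dispatch: the number of work-groups launched
// (assuming each work-group occupies one subslice) divided by the number of
// subslices available on the device. Values above 1.0 mean oversubscription.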
inline float get_ss_utilization(
        int max_ss, const size_t *gws, const size_t *lws) {
    const size_t gws_size = gws[0] * gws[1] * gws[2];
    const size_t lws_size = lws[0] * lws[1] * lws[2];
    const size_t used_ss = utils::div_up(gws_size, lws_size);
    return (float)used_ss / max_ss;
}

// Local work-group size adjustment for the stat calculation kernel.
void adjust_lws_calc_kernel(bnorm_conf_t &conf, engine_t *engine) {
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    auto eu_count = compute_engine->device_info()->eu_count();
    auto max_lws = compute_engine->device_info()->max_wg_size();
    auto eus_per_ss = compute_engine->device_info()->max_eus_per_wg();
    const int max_ss = utils::div_up(eu_count, eus_per_ss);

    auto generated_nd = conf.dispatch_calc_stat.nd_range();
    const size_t *base_gws = generated_nd.global_range();
    const size_t *base_lws = generated_nd.local_range();

    size_t tuned_lws[3];
    tuned_lws[0] = 16; // Assuming IC is dim 0
    tuned_lws[1] = base_lws[1];
    tuned_lws[2] = base_lws[2];

    // The search is based on subslice utilization, which is calculated as
    // the ratio used_subslices / max_available_subslices.

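    // E.g., with hypothetical values base_gws = {256, 96, 1}, max_lws = 256,
    // and max_ss = 64: lws1 candidates are the divisors of 96 not exceeding
    // 256 / 16 = 16; lws1 = 16 gives used_ss = 24576 / 256 = 96 and a
    // utilization of 96 / 64 = 1.5, the only value below the limit of 2,
    // so lws1 = 16 is selected.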
    size_t best_lws1 = 1, curr_lws1 = 1;
    float best_ss_utilization = 0.0f, curr_ss_utilization;
    const int ss_util_limit = 2; // experimentally selected

    while (tuned_lws[0] * curr_lws1 * tuned_lws[2] <= (size_t)max_lws
            && curr_lws1 <= base_gws[1]) {
        if (base_gws[1] % curr_lws1) {
            curr_lws1++;
            continue;
        }
        tuned_lws[1] = curr_lws1;
        curr_ss_utilization = get_ss_utilization(max_ss, base_gws, tuned_lws);
        if (curr_ss_utilization > best_ss_utilization
                && curr_ss_utilization < (float)ss_util_limit) {
            best_ss_utilization = curr_ss_utilization;
            best_lws1 = curr_lws1;
        }
        curr_lws1++;
    }
    tuned_lws[1] = best_lws1;

    conf.dispatch_calc_stat.set_lws(tuned_lws);
}

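// Splits IC into at least two blocks that are multiples of the SIMD width
// when IC is large enough; otherwise keeps IC as a single block. E.g.,
// ic = 512 gives nblocks = 512 / 128 = 4 and an ic_block of 128, while
// ic = 128 (nblocks = 1) stays unsplit.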
int get_nhwc_ic_block(int ic, int max_vect_size = 8) {
    const int nblocks = ic / (max_vect_size * 16);
    return nblocks < 2 || (ic / nblocks) % 16 ? ic : ic / nblocks;
}
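// Largest power-of-two vector size (up to 8) such that ic still covers at
// least one full vect_size * simd chunk; e.g., ic = 128 returns 8
// (128 / (8 * 16) = 1), and ic = 32 returns 2.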
int get_nhwc_vect_size(int ic, int simd = 16) {
    int vect_size = 8;
    // Bounding the loop at vect_size == 1 avoids a division by zero for
    // ic < simd and makes the fallback below reachable.
    while (vect_size > 1) {
        if (ic / (vect_size * simd)) return vect_size;
        vect_size /= 2;
    }
    return 1;
}

int get_nhwc_sp_block_size(
        int sp, int ic_dim, int eu_count, int threads_per_eu, int simd = 16) {

    float efficiency_thr = 0.0f;
    float efficiency_peak_eu_thr = 0.0f;
    int block_size_thr = 1;
    int block_size_peak_eu_thr = 1;
    int curr_block_size = sp;
    int nthr_mul = 1;
    const int ic_nsg = ic_dim / simd; // number of subgroups along the IC dim

    // The search is based on thread wave efficiency.
    // Cases with peak EU utilization get higher priority.
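    // E.g., with hypothetical values sp = 3136, ic_dim = 256 (ic_nsg = 16),
    // eu_count = 512, threads_per_eu = 8: at nthr_mul = 1, nthr = 512,
    // curr_block_size = div_up(3136 * 16, 512) = 98, nblock = 32, and
    // nthr_gen = 32 * 16 = 512, an exact multiple of eu_count, so
    // curr_efficiency_eus == 1 and this block size becomes a peak-EU
    // candidate.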
    while (nthr_mul <= 32) {
        const int nthr = nthr_mul * eu_count;
        curr_block_size = utils::div_up(sp * ic_nsg, nthr);
        const int nblock = utils::div_up(sp, curr_block_size);
        const int nthr_gen = nblock * ic_nsg;

        const float curr_efficiency_eus
                = (float)nthr_gen / utils::rnd_up(nthr_gen, eu_count);
        const float curr_efficiency_thr = (float)nthr_gen
                / utils::rnd_up(nthr_gen, eu_count * threads_per_eu);

        if (curr_efficiency_thr > efficiency_thr) {
            efficiency_thr = curr_efficiency_thr;
            block_size_thr = curr_block_size;
        }
        if (curr_efficiency_eus == 1
                && curr_efficiency_thr > efficiency_peak_eu_thr) {
            efficiency_peak_eu_thr = curr_efficiency_thr;
            block_size_peak_eu_thr = curr_block_size;
        }
        nthr_mul++;
    }
    if (efficiency_peak_eu_thr > 0.0f) return block_size_peak_eu_thr;
    return block_size_thr;
}

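// Chooses the spatial block size for blocked layouts by maximizing thread
// wave efficiency: threads_generated * nn relative to its round-up to a
// multiple of hw_threads. A worked first iteration with hypothetical values
// work_size = 3136, ic = 256, nn = 2, hw_threads = 512 (forward, so
// align_size = 16): nof_blocks = 8192 / 256 = 32, min_block_size = 98,
// curr_block_size = rnd_up(98, 16) = 112, nof_blocks_generated = 28,
// threads_generated = 448, efficiency = 896 / 1024 = 0.875; since 112 <= 256,
// the loop stops and 112 is returned.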
int get_block_size(bool is_backward, int hw_threads, int nn, int ic,
        int work_size, int simd = 16) {
    int block_size = 256;
    float thread_efficiency = 0;
    int hw_thread_mult = hw_threads;
    const int align_size = is_backward ? 8 : 16;
    while (true) {
        const int nof_blocks
                = nstl::max(utils::rnd_dn(hw_thread_mult * simd, ic) / ic, 1);
        const int min_block_size
                = utils::rnd_up(work_size, nof_blocks) / nof_blocks;
        const int curr_block_size = utils::rnd_up(min_block_size, align_size);
        const int nof_blocks_generated
                = utils::rnd_up(work_size, curr_block_size) / curr_block_size;
        const int threads_generated = nof_blocks_generated * ic / simd;
        const float curr_thread_efficiency = float(threads_generated * nn)
                / float(utils::rnd_up(threads_generated * nn, hw_threads));
        if (curr_thread_efficiency > thread_efficiency) {
            thread_efficiency = curr_thread_efficiency;
            block_size = curr_block_size;
        }
        if (curr_thread_efficiency == 1.0 || curr_block_size <= 256) { break; }
        hw_thread_mult += hw_threads;
    }
    return block_size;
}

static status_t init_conf_common(bnorm_conf_t &conf, offsets_t &off,
        const batch_normalization_pd_t *pd, engine_t *engine) {
    using namespace dnnl::impl::format_tag;

    const batch_normalization_desc_t &bd = *pd->desc();
    const memory_desc_wrapper data_mdw(
            pd->is_fwd() ? pd->src_md() : pd->diff_src_md());
    const int ndims = data_mdw.ndims();

    conf.data_type = data_mdw.data_type();

    conf.ndims = ndims;
    conf.mb = data_mdw.dims()[0];

    conf.ic = data_mdw.dims()[1];
    conf.id = (ndims == 5) ? data_mdw.dims()[2] : 1;
    conf.ih = (ndims == 3) ? 1 : data_mdw.dims()[ndims - 2];
    conf.iw = data_mdw.dims()[ndims - 1];

    conf.is_forward = pd->is_fwd();
    conf.is_backward = !pd->is_fwd();

    conf.use_scale = pd->use_scale();
    conf.use_shift = pd->use_shift();
    conf.save_stats = pd->is_training();
    conf.is_training = pd->is_training();
    conf.fuse_norm_add_relu = pd->fuse_norm_add_relu();
    conf.fuse_norm_relu = pd->fuse_norm_relu() || pd->fuse_norm_add_relu();
    conf.calculate_stats = !pd->stats_is_src();
    conf.with_relu = pd->with_relu_post_op();
    conf.eps = bd.batch_norm_epsilon;
    conf.calculate_diff_stats = !pd->use_global_stats();
    conf.diff_scale = (pd->use_scale() && bd.prop_kind == prop_kind::backward);
    conf.diff_shift = (pd->use_shift() && bd.prop_kind == prop_kind::backward);

    set_offsets(data_mdw, off.src_off);

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    auto gpu_arch = compute_engine->device_info()->gpu_arch();

    conf.mb_block = 1;

    const bool has_padding = !data_mdw.is_dense();
    const bool is_blocked_16c
            = data_mdw.matches_one_of_tag(nCw16c, nChw16c, nCdhw16c);
    const bool is_blocked_16n16c
            = data_mdw.matches_one_of_tag(NCw16n16c, NChw16n16c, NCdhw16n16c);
    const bool is_blocked_32n16c
            = data_mdw.matches_one_of_tag(NCw32n16c, NChw32n16c, NCdhw32n16c);
    const bool is_nhwc
            = conf.ic % 8 == 0 && data_mdw.matches_one_of_tag(nwc, nhwc, ndhwc);

    // Because intel_sub_group_write_uc requires 16-byte alignment, IC tail
    // processing (IC divisible by 8 but not by 16) is not applicable with
    // fuse_norm_relu or the char data type.
    if (conf.ic % 8 == 0 && conf.ic % 16
            && (conf.fuse_norm_relu || conf.data_type == data_type::s8))
        return status::unimplemented;
    // The performance benefit of IC tail processing is not evident on
    // architectures older than xe_hpc.
    if (conf.ic % 8 == 0 && conf.ic % 16
            && gpu_arch < compute::gpu_arch_t::xe_hpc)
        return status::unimplemented;

    conf.use_stats_one_pass = experimental::use_bnorm_stats_one_pass();

    // IC tail processing is not implemented yet for the one-pass algorithm.
    // TODO: implement it; the potential perf boost could be ~2x.
    if (conf.ic % 8 == 0 && conf.ic % 16 && conf.use_stats_one_pass)
        conf.use_stats_one_pass = false;

    conf.use_nhwc = is_nhwc;

    if (has_padding
            || !(is_blocked_16c || is_blocked_16n16c || is_blocked_32n16c
                    || is_nhwc))
        return status::unimplemented;

    // IC tail processing is not implemented yet for the NHWC-optimized
    // kernels. TODO: implement it.
    // The NHWC-optimized kernels are limited to XeHPG+ for performance
    // reasons.
    conf.nhwc_optimized = conf.ic % 16 == 0
            && data_mdw.matches_one_of_tag(nwc, nhwc, ndhwc)
            && gpu_arch >= compute::gpu_arch_t::xe_hpg;

    conf.use_fused_atomics_reduction
            = use_fused_atomics_reduction(conf, engine);

    if (conf.nhwc_optimized) {
        conf.ic_block = get_nhwc_ic_block(utils::rnd_up(conf.ic, 16));
    } else {
        conf.ic_block = 16;
    }

    conf.mb_block = is_blocked_32n16c ? 32 : is_blocked_16n16c ? 16 : 1;

    if (is_nhwc) {
        // reshape to xc
        conf.nn = 1;
        conf.sp = conf.mb * conf.id * conf.ih * conf.iw;
    } else {
        // reshape to nCx16c
        conf.nn = conf.mb / conf.mb_block;
        conf.sp = conf.id * conf.ih * conf.iw * conf.mb_block;
    }
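    // E.g., a hypothetical NChw32n16c case with mb = 64 gives mb_block = 32,
    // nn = 64 / 32 = 2, and sp = ih * iw * 32 (id = 1 for 4D).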

    // The IC == 8 case requires the spatial dim to be even because one block
    // read/write operation covers two spatial rows at once.
    if (is_nhwc && conf.ic == 8 && conf.sp % 2) return status::unimplemented;

    conf.calc_stat_ic = conf.nhwc_optimized
            ? utils::div_up(conf.ic, conf.ic_block) * 16
            : utils::rnd_up(conf.ic, 16);

    auto eu_count = compute_engine->device_info()->eu_count();
    auto threads_per_eu
            = compute::device_info_t::threads_per_eu(gpu_arch, false);
    const int max_sp_block_size = get_block_size(conf.is_backward, eu_count,
            conf.nn, utils::rnd_up(conf.ic, 16), conf.sp);

    if (conf.nn == 1)
        conf.stat_sp_block = conf.nhwc_optimized
                ? get_nhwc_sp_block_size(
                        conf.sp, conf.calc_stat_ic, eu_count, threads_per_eu)
                : max_sp_block_size;
    else
        conf.stat_sp_block
                = nstl::min(utils::rnd_up(conf.sp, 16), max_sp_block_size);

    conf.stat_sp_nblocks
            = utils::rnd_up(conf.sp, conf.stat_sp_block) / conf.stat_sp_block;
    conf.stat_sp_tail
            = utils::rnd_dn(conf.sp, conf.stat_sp_block) / conf.stat_sp_block;
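    // stat_sp_nblocks counts all blocks including a partial tail block, while
    // stat_sp_tail is the number of full blocks, i.e. the index of the block
    // where the tail starts. E.g., sp = 100 with stat_sp_block = 16 gives
    // stat_sp_nblocks = 7 and stat_sp_tail = 6.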

    conf.reduce_stat_nblocks = conf.nn * conf.stat_sp_nblocks;

    conf.vect_size
            = conf.nhwc_optimized ? get_nhwc_vect_size(conf.ic_block) : 8;

    conf.dispatch_calc_stat = compute_engine->create_dispatch();
    conf.dispatch_calc_stat.define_dim("STAT_MB", 0, conf.nn);
    conf.dispatch_calc_stat.define_dim("STAT_SP", 1, conf.stat_sp_nblocks);
    conf.dispatch_calc_stat.define_dim_with_nesting_level(
            "STAT_IC", 1, conf.calc_stat_ic);
    CHECK(conf.dispatch_calc_stat.vectorize_dim("STAT_IC", 16));
    conf.dispatch_calc_stat.set_kernel_attr_suffix("CALC");
    conf.dispatch_calc_stat.generate();
    if (conf.use_fused_atomics_reduction)
        adjust_lws_calc_kernel(conf, compute_engine);

    conf.dispatch_reduce_stat = compute_engine->create_dispatch();
    int reduce_sub_group_count = 1;
    while (conf.reduce_stat_nblocks % (2 * reduce_sub_group_count) == 0
            && 2 * reduce_sub_group_count * 16 <= 256) {
        reduce_sub_group_count = reduce_sub_group_count * 2;
    }
    conf.stat_ic = reduce_sub_group_count * 16;
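    // The loop above doubles the subgroup count while twice the count still
    // evenly divides reduce_stat_nblocks and the work-group stays within 256
    // items. E.g., reduce_stat_nblocks = 48 (hypothetical) gives
    // reduce_sub_group_count = 16, hence stat_ic = 256.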
    conf.dispatch_reduce_stat.define_dim("REDUCE_STAT_IC", 0, conf.stat_ic);
    conf.dispatch_reduce_stat.define_dim(
            "REDUCE_IC_GROUP", 1, utils::rnd_up(conf.ic, 16) / 16);
    CHECK(conf.dispatch_reduce_stat.vectorize_dim("REDUCE_STAT_IC", 16));
    conf.dispatch_reduce_stat.set_kernel_attr_suffix("REDUCE");
    conf.dispatch_reduce_stat.generate();

    const int sp_pad = utils::rnd_up(conf.sp, conf.vect_size);
    conf.sp_tail = utils::rnd_dn(conf.sp, conf.vect_size);

    conf.dispatch = compute_engine->create_dispatch(data_mdw.md_);
    conf.dispatch.define_dim("MB", 0, conf.nn);
    conf.dispatch.define_dim("SP", 1,
            conf.nhwc_optimized ? conf.stat_sp_nblocks
                                : sp_pad / conf.vect_size);
    conf.dispatch.define_dim("IC", 2, conf.calc_stat_ic);

    CHECK(conf.dispatch.vectorize_dim("IC", 16));
    conf.dispatch.generate();

    conf.dispatch_reduce_aux = compute_engine->create_dispatch(data_mdw.md_);
    conf.dispatch_reduce_aux.define_dim("IC_AUX", 0, conf.ic);
    conf.dispatch_reduce_aux.set_kernel_attr_suffix("AUX");
    conf.dispatch_reduce_aux.generate();

    return status::success;
}

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const bnorm_conf_t &conf, const offsets_t &off) {
    kernel_ctx.set_data_type(conf.data_type);

    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("MB", conf.mb);
    kernel_ctx.define_int("IC", conf.ic);
    kernel_ctx.define_int("IC16", utils::rnd_up(conf.ic, 16));
    kernel_ctx.define_int("ID", conf.id);
    kernel_ctx.define_int("IH", conf.ih);
    kernel_ctx.define_int("IW", conf.iw);
    kernel_ctx.define_int("MB_BLOCK", conf.mb_block);
    kernel_ctx.define_int("IC_BLOCK", conf.ic_block);

    kernel_ctx.define_int("USE_NHWC", conf.use_nhwc);
    kernel_ctx.define_int("SP", conf.sp);
    kernel_ctx.define_int("SP_TAIL", conf.sp_tail);
    kernel_ctx.define_int("VECT_SIZE", conf.vect_size);

    kernel_ctx.define_int("STAT_SP_BLOCK", conf.stat_sp_block);
    kernel_ctx.define_int("STAT_SP_NBLOCKS", conf.stat_sp_nblocks);
    kernel_ctx.define_int("STAT_SP_TAIL", conf.stat_sp_tail);
    kernel_ctx.define_int("REDUCE_STAT_NBLOCKS", conf.reduce_stat_nblocks);

    if (conf.is_forward)
        kernel_ctx.define_int("IS_FWD", 1);
    else if (conf.is_backward)
        kernel_ctx.define_int("IS_BWD", 1);

    kernel_ctx.define_int("WITH_RELU", conf.with_relu);
    kernel_ctx.define_int("SAVE_STATS", conf.save_stats);
    kernel_ctx.define_int("IS_TRAINING", conf.is_training);
    kernel_ctx.define_int("FUSE_BN_RELU", conf.fuse_norm_relu);
    kernel_ctx.define_int("FUSE_BN_ADD_RELU", conf.fuse_norm_add_relu);
    kernel_ctx.define_int("CALCULATE_STATS", conf.calculate_stats);
    kernel_ctx.define_int("USE_SCALE", conf.use_scale);
    kernel_ctx.define_int("USE_SHIFT", conf.use_shift);
    kernel_ctx.define_int("CALCULATE_DIFF_STATS", conf.calculate_diff_stats);
    kernel_ctx.define_int("DIFF_SCALE", conf.diff_scale);
    kernel_ctx.define_int("DIFF_SHIFT", conf.diff_shift);
    kernel_ctx.define_int("REDUCE_IC_SUB_GROUPS", conf.stat_ic / 16);
    kernel_ctx.define_int("USE_STATS_ONE_PASS", conf.use_stats_one_pass);
    kernel_ctx.define_int("NHWC_OPTIMIZED", conf.nhwc_optimized);
    kernel_ctx.define_int(
            "FUSED_ATOMICS_REDUCTION", conf.use_fused_atomics_reduction);

    kernel_ctx.add_option("-cl-std=CL2.0");
    if (conf.data_type == data_type::s8)
        kernel_ctx.add_option("-Dcl_intel_subgroups_char");

    def_offsets(off.src_off, kernel_ctx, "SRC", conf.ndims);

    def_dispatch(kernel_ctx, conf.dispatch_calc_stat);
    def_dispatch(kernel_ctx, conf.dispatch_reduce_stat);
    def_dispatch(kernel_ctx, conf.dispatch_reduce_aux);
    def_dispatch(kernel_ctx, conf.dispatch);

    return status::success;
}

status_t gen9_batch_normalization_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}

status_t gen9_batch_normalization_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off);
}

void gen9_batch_normalization_fwd_t::pd_t::init_scratchpad() {
    if (conf.calculate_stats) {
        size_t size_coeff = sizeof(double) / sizeof(float);
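        // The factor of 2 below covers partial sums for both mean and
        // variance; size_coeff reserves room for accumulators twice the f32
        // width, presumably for the one-pass variant's wider accumulation.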
        size_t size = 2 * size_coeff * conf.reduce_stat_nblocks
                * utils::rnd_up(conf.ic, 16);

        auto scratchpad = scratchpad_registry().registrar();
        scratchpad.book(key_bnorm_reduction, size,
                types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
        if (!conf.save_stats) {
            scratchpad.book(key_bnorm_tmp_mean, conf.ic,
                    types::data_type_size(data_type::f32),
                    OCL_BUFFER_ALIGNMENT);
            scratchpad.book(key_bnorm_tmp_var, conf.ic,
                    types::data_type_size(data_type::f32),
                    OCL_BUFFER_ALIGNMENT);
        }
    }
}

status_t gen9_batch_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &src_add = CTX_IN_STORAGE(DNNL_ARG_SRC_1);

    auto &mean_ = pd()->stats_is_src()
            ? CTX_IN_STORAGE(DNNL_ARG_MEAN)
            : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_MEAN, status);
    CHECK(status);

    auto &variance_ = pd()->stats_is_src()
            ? CTX_IN_STORAGE(DNNL_ARG_VARIANCE)
            : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_VARIANCE, status);
    CHECK(status);

    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &shift = CTX_IN_STORAGE(DNNL_ARG_SHIFT);

    auto &dst = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status);
    CHECK(status);
    auto &ws = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_WORKSPACE, status);
    CHECK(status);

    std::unique_ptr<memory_storage_t> temp_reduce;
    std::unique_ptr<memory_storage_t> tmp_mean;
    std::unique_ptr<memory_storage_t> tmp_variance;
    if (conf.calculate_stats) {
        temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
                key_bnorm_reduction);

        if (!conf.save_stats) {
            tmp_mean = ctx.get_scratchpad_grantor().get_memory_storage(
                    key_bnorm_tmp_mean);
            tmp_variance = ctx.get_scratchpad_grantor().get_memory_storage(
                    key_bnorm_tmp_var);
        }
    }

    auto &mean = (conf.calculate_stats && !conf.save_stats) ? *tmp_mean : mean_;
    auto &variance = (conf.calculate_stats && !conf.save_stats) ? *tmp_variance
                                                                : variance_;

    if (conf.calculate_stats && conf.use_fused_atomics_reduction) {
        // Atomics-based reduction requires zeroing mean and variance
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, mean);
        arg_list.set(1, variance);

        auto nd_range = conf.dispatch_reduce_aux.nd_range();
        status = parallel_for(ctx, nd_range, reduce_init_kernel_, arg_list);
        if (status != status::success) return status;
    }

    if (conf.calculate_stats && !conf.use_stats_one_pass) {
        compute::kernel_arg_list_t calc_mean_arg_list;
        calc_mean_arg_list.set(0, src);
        calc_mean_arg_list.set(1, *temp_reduce);
        calc_mean_arg_list.set(2, mean);

        auto nd_range_calc_mean = conf.dispatch_calc_stat.nd_range();

        status = parallel_for(ctx, nd_range_calc_mean, calculate_mean_kernel_,
                calc_mean_arg_list);
        if (status != status::success) return status;

        if (conf.use_fused_atomics_reduction) {
            compute::kernel_arg_list_t arg_list;
            arg_list.set(0, mean);
            auto nd_range = conf.dispatch_reduce_aux.nd_range();
            status = parallel_for(
                    ctx, nd_range, reduce_final_kernel_, arg_list);
            if (status != status::success) return status;
        } else {
            compute::kernel_arg_list_t reduce_mean_arg_list;
            reduce_mean_arg_list.set(0, *temp_reduce);
            reduce_mean_arg_list.set(1, mean);

            auto nd_range_reduce_mean = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_mean,
                    reduce_mean_kernel_, reduce_mean_arg_list);
            if (status != status::success) return status;
        }

        compute::kernel_arg_list_t calc_var_arg_list;
        calc_var_arg_list.set(0, src);
        calc_var_arg_list.set(1, mean);
        calc_var_arg_list.set(2, *temp_reduce);
        calc_var_arg_list.set(3, variance);

        auto nd_range_calc_var = conf.dispatch_calc_stat.nd_range();

        status = parallel_for(ctx, nd_range_calc_var,
                calculate_variance_kernel_, calc_var_arg_list);
        if (status != status::success) return status;

        if (conf.use_fused_atomics_reduction) {
            compute::kernel_arg_list_t arg_list;
            arg_list.set(0, variance);
            auto nd_range = conf.dispatch_reduce_aux.nd_range();
            status = parallel_for(
                    ctx, nd_range, reduce_final_kernel_, arg_list);
            if (status != status::success) return status;
        } else {
            compute::kernel_arg_list_t reduce_var_arg_list;
            reduce_var_arg_list.set(0, *temp_reduce);
            reduce_var_arg_list.set(1, variance);

            auto nd_range_reduce_var = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_var,
                    reduce_variance_kernel_, reduce_var_arg_list);
            if (status != status::success) return status;
        }
    }
    if (conf.calculate_stats && conf.use_stats_one_pass) {
        compute::kernel_arg_list_t calc_mean_var_arg_list;
        calc_mean_var_arg_list.set(0, src);
        calc_mean_var_arg_list.set(1, *temp_reduce);
        calc_mean_var_arg_list.set(2, mean);
        calc_mean_var_arg_list.set(3, variance);

        auto nd_range_calc_mean = conf.dispatch_calc_stat.nd_range();

        status = parallel_for(ctx, nd_range_calc_mean,
                calculate_mean_var_kernel_, calc_mean_var_arg_list);
        if (status != status::success) return status;

        if (conf.use_fused_atomics_reduction) {
            compute::kernel_arg_list_t reduce_final_arg_list;
            reduce_final_arg_list.set(0, mean);
            reduce_final_arg_list.set(1, variance);
            auto nd_range_reduce_final = conf.dispatch_reduce_aux.nd_range();

            status = parallel_for(ctx, nd_range_reduce_final,
                    reduce_final_kernel_, reduce_final_arg_list);
            if (status != status::success) return status;
        } else {
            compute::kernel_arg_list_t reduce_mean_var_arg_list;
            reduce_mean_var_arg_list.set(0, *temp_reduce);
            reduce_mean_var_arg_list.set(1, mean);
            reduce_mean_var_arg_list.set(2, variance);

            auto nd_range_reduce_mean = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_mean,
                    reduce_mean_var_kernel_, reduce_mean_var_arg_list);
            if (status != status::success) return status;
        }
    }
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, dst);
    arg_list.set(4, scale);
    arg_list.set(5, shift);
    arg_list.set(6, ws);
    arg_list.set(7, conf.eps);
    arg_list.set(8, src_add);

    auto nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);
    return status;
}

status_t gen9_batch_normalization_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, off, this, engine);
}

status_t gen9_batch_normalization_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, off);
}

void gen9_batch_normalization_bwd_t::pd_t::init_scratchpad() {
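    // A reading of the sizing below: per-reduce-block f32 partial sums plus
    // one extra IC-sized slot each (the "1 +"), times 2 to cover both
    // diff_scale and diff_shift.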
    size_t size
            = 2 * utils::rnd_up(conf.ic, 16) * (1 + conf.reduce_stat_nblocks);
    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(key_bnorm_reduction, size,
            types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
}

status_t gen9_batch_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;

    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &mean = CTX_IN_STORAGE(DNNL_ARG_MEAN);
    auto &variance = CTX_IN_STORAGE(DNNL_ARG_VARIANCE);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE);

    auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
    CHECK(status);
    auto &diff_src_add = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC_1, status);
    CHECK(status);

    auto &diff_scale_ = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SCALE, status);
    CHECK(status);
    auto &diff_shift_ = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SHIFT, status);
    CHECK(status);

    std::unique_ptr<memory_storage_t> temp_reduce;
    temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
            key_bnorm_reduction);

    auto &diff_scale = !conf.diff_scale ? *temp_reduce : diff_scale_;
    auto &diff_shift = !conf.diff_shift ? *temp_reduce : diff_shift_;

    if (conf.use_fused_atomics_reduction) {
        compute::kernel_arg_list_t reduce_init_arg_list;
        reduce_init_arg_list.set(0, diff_scale);
        reduce_init_arg_list.set(1, diff_shift);

        auto nd_range_reduce_init = conf.dispatch_reduce_aux.nd_range();
        status = parallel_for(ctx, nd_range_reduce_init, reduce_init_kernel_,
                reduce_init_arg_list);
        if (status != status::success) return status;
    }

    compute::kernel_arg_list_t calc_stats_arg_list;
    calc_stats_arg_list.set(0, src);
    calc_stats_arg_list.set(1, mean);
    calc_stats_arg_list.set(2, diff_dst);
    calc_stats_arg_list.set(3, ws);
    calc_stats_arg_list.set(4, *temp_reduce);
    calc_stats_arg_list.set(5, diff_scale);
    calc_stats_arg_list.set(6, diff_shift);

    auto nd_range = conf.dispatch_calc_stat.nd_range();
    status = parallel_for(
            ctx, nd_range, calculate_stats_kernel_, calc_stats_arg_list);
    if (status != status::success) return status;

    if (conf.use_fused_atomics_reduction) {
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, diff_scale);
        arg_list.set(1, variance);
        arg_list.set(2, conf.eps);
        auto nd_range = conf.dispatch_reduce_aux.nd_range();
        status = parallel_for(ctx, nd_range, reduce_final_kernel_, arg_list);
        if (status != status::success) return status;
    } else {
        compute::kernel_arg_list_t reduce_stats_arg_list;
        reduce_stats_arg_list.set(0, *temp_reduce);
        reduce_stats_arg_list.set(1, diff_scale);
        reduce_stats_arg_list.set(2, diff_shift);
        reduce_stats_arg_list.set(3, variance);
        reduce_stats_arg_list.set(4, conf.eps);

        auto nd_range_reduce_stat = conf.dispatch_reduce_stat.nd_range();
        status = parallel_for(ctx, nd_range_reduce_stat, reduce_stats_kernel_,
                reduce_stats_arg_list);
        if (status != status::success) return status;
    }

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, diff_dst);
    arg_list.set(4, scale);
    arg_list.set(5, ws);
    arg_list.set(6, diff_src);
    arg_list.set(7, diff_scale);
    arg_list.set(8, diff_shift);
    arg_list.set(9, conf.eps);
    arg_list.set(10, diff_src_add);

    nd_range = conf.dispatch.nd_range();
    status = parallel_for(ctx, nd_range, bwd_kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl