1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_batch_normalization.hpp" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_traits.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/scratchpad.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | using namespace dnnl::impl::memory_tracking::names; |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | namespace ocl { |
31 | |
// Fills `conf` (kernel configuration) and `off` (source-tensor offset tables)
// shared by the forward and backward reference batch normalization kernels,
// and builds the nd-range dispatch descriptors for the statistics-calculation,
// statistics-reduction, and main kernels.
static status_t init_conf_common(bnorm_conf_t &conf, offsets_t &off,
        const batch_normalization_pd_t *pd, engine_t *engine) {
    using namespace dnnl::impl::format_tag;

    const batch_normalization_desc_t &bd = *pd->desc();
    // The dispatched layout is taken from src on forward, diff_src on backward.
    const memory_desc_wrapper data_mdw(
            pd->is_fwd() ? pd->src_md() : pd->diff_src_md());
    const int ndims = data_mdw.ndims();

    conf = utils::zero<decltype(conf)>();
    conf.data_type = data_mdw.data_type();

    conf.ndims = ndims;
    conf.mb = data_mdw.dims()[0];

    conf.ic = data_mdw.dims()[1];
    // Spatial dims that a lower-rank tensor lacks collapse to 1.
    conf.id = (ndims >= 5) ? data_mdw.dims()[ndims - 3] : 1;
    conf.ih = (ndims >= 4) ? data_mdw.dims()[ndims - 2] : 1;
    conf.iw = (ndims >= 3) ? data_mdw.dims()[ndims - 1] : 1;

    conf.is_forward = pd->is_fwd();
    conf.is_backward = !pd->is_fwd();

    conf.use_scale = pd->use_scale();
    conf.use_shift = pd->use_shift();
    conf.save_stats = pd->is_training();
    conf.is_training = pd->is_training();
    // fuse_norm_add_relu implies the plain ReLU fusion flag as well.
    conf.fuse_norm_relu = pd->fuse_norm_relu() || pd->fuse_norm_add_relu();
    conf.fuse_norm_add_relu = pd->fuse_norm_add_relu();
    conf.calculate_stats = !pd->stats_is_src();
    conf.with_relu = pd->with_relu_post_op();
    conf.eps = bd.batch_norm_epsilon;
    conf.calculate_diff_stats = !pd->use_global_stats();
    conf.diff_scale = (pd->use_scale() && bd.prop_kind == prop_kind::backward);
    conf.diff_shift = (pd->use_shift() && bd.prop_kind == prop_kind::backward);

    set_offsets(data_mdw, off.src_off);

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

    conf.use_16mb_unroll = false;
    conf.use_nhwc = false;
    conf.mb_block = 1;
    conf.ic_block = 1;

    const bool has_padding = !data_mdw.is_dense();

    // Optimized path: backward pass on dense 16-channel-blocked layouts.
    if (!has_padding && conf.is_backward
            && data_mdw.matches_one_of_tag(nCw16c, nChw16c, nCdhw16c, NCw16n16c,
                    NChw16n16c, NCdhw16n16c)) {
        // MB is additionally blocked by 16 only for the 16n16c tags.
        conf.mb_block = data_mdw.matches_one_of_tag(
                                NCw16n16c, NChw16n16c, NCdhw16n16c)
                ? 16
                : 1;
        conf.ic_block = 16;
        conf.use_16mb_unroll = true;

        // Partition the (MB-blocks x spatial) space into at most 256 partial-
        // statistic chunks; stat_sp_nblocks is the largest divisor of the
        // spatial size that keeps the product within the budget.
        const int max_stat_nblocks = 256;
        int stat_mb_nblocks = conf.mb / conf.mb_block;
        int stat_sp_nblocks = utils::max_div(conf.id * conf.ih * conf.iw,
                nstl::max(1, max_stat_nblocks / stat_mb_nblocks));
        assert(stat_mb_nblocks * stat_sp_nblocks <= max_stat_nblocks);

        int stat_sp_block = conf.id * conf.ih * conf.iw / stat_sp_nblocks;

        conf.reduce_stat_nblocks = stat_mb_nblocks * stat_sp_nblocks;

        conf.dispatch_calc_stat = compute_engine->create_dispatch();
        conf.dispatch_calc_stat.define_dim_with_nesting_level(
                "STAT_SP", 2, conf.id * conf.ih * conf.iw, stat_sp_block);
        conf.dispatch_calc_stat.define_dim_with_nesting_level(
                "STAT_IC", 1, conf.ic);
        conf.dispatch_calc_stat.define_dim_with_nesting_level(
                "STAT_MB", 0, conf.mb, conf.mb_block);
        // Channels are handled 16 at a time by one subgroup.
        CHECK(conf.dispatch_calc_stat.vectorize_dim("STAT_IC", 16));
        conf.dispatch_calc_stat.set_kernel_attr_suffix("CALC");
        conf.dispatch_calc_stat.generate();

        conf.dispatch_reduce_stat = compute_engine->create_dispatch();
        conf.dispatch_reduce_stat.define_dim("REDUCE_STAT_IC", conf.ic);
        conf.dispatch_reduce_stat.set_kernel_attr_suffix("REDUCE");
        conf.dispatch_reduce_stat.generate();

        conf.dispatch = compute_engine->create_dispatch(data_mdw.md_);
        conf.dispatch.define_dim("MB", 0, conf.mb, conf.mb_block);
        conf.dispatch.define_dim("IC", 1, conf.ic);
        conf.dispatch.define_dim("ID", nstl::max(1, ndims - 3), conf.id);
        conf.dispatch.define_dim("IH", nstl::max(1, ndims - 2), conf.ih);
        conf.dispatch.define_dim("IW", nstl::max(1, ndims - 1), conf.iw);
        CHECK(conf.dispatch.vectorize_dim("IC", 16));
        conf.dispatch.generate();
    } else {
        // Reference
        conf.use_16mb_unroll = false;
        conf.dispatch = compute_engine->create_dispatch(data_mdw.md_);
        conf.dispatch.define_dim("MB", 0, conf.mb);
        conf.dispatch.define_dim("IC", 1, conf.ic);
        conf.dispatch.define_dim("ID", nstl::max(1, ndims - 3), conf.id);
        conf.dispatch.define_dim("IH", nstl::max(1, ndims - 2), conf.ih);
        conf.dispatch.define_dim("IW", nstl::max(1, ndims - 1), conf.iw);

        conf.dispatch.generate();
        // Statistics kernels are needed when stats are computed here (forward
        // training) or for the backward diff-stat reduction.
        if (conf.calculate_stats || conf.is_backward) {

            conf.dispatch_calc_stat
                    = compute_engine->create_dispatch(data_mdw.md_);
            // Normalize to a fixed {MB, IC, ID, IH, IW} quintet; dims absent
            // from lower-rank tensors become 1.
            int calc_dims[5];
            auto &dims = data_mdw.dims();
            calc_dims[0] = dims[0];
            calc_dims[1] = dims[1];
            calc_dims[2] = (ndims < 5) ? 1 : dims[ndims - 3];
            calc_dims[3] = (ndims < 4) ? 1 : dims[ndims - 2];
            calc_dims[4] = (ndims < 3) ? 1 : dims[ndims - 1];
            // Reduce over the largest of MB and the spatial dims; IC (index 1)
            // is deliberately excluded from the candidates.
            int reduce_dim_idx = 0;
            for (int i = 2; i < 5; i++) {
                if (calc_dims[i] > calc_dims[reduce_dim_idx]) {
                    reduce_dim_idx = i;
                }
            }
            conf.reduce_dim = calc_dims[reduce_dim_idx];
            conf.reduce_dim_idx = reduce_dim_idx;
            const std::string dim_names[5]
                    = {"STAT_MB", "STAT_IC", "STAT_ID", "STAT_IH", "STAT_IW"};
            const std::string &reduce_dim_name = dim_names[reduce_dim_idx];

            conf.vectorize_calc_stats = false;
            conf.vect_size = 1;
            conf.sub_group_size = 1;
            int calc_dims_blocks[5] = {1, 1, 1, 1, 1};

            // Translate reduce_dim_idx from being an index in calc_dims to dims array
            const int base_reduce_dim_idx
                    = reduce_dim_idx == 0 ? 0 : reduce_dim_idx - (5 - ndims);
            const int reduce_dim_stride
                    = data_mdw.blocking_desc().strides[base_reduce_dim_idx];
            // Vectorized stat calculation requires: forward pass, reduce dim
            // divisible by the subgroup size (16), and unit stride so the
            // subgroup reads contiguous memory.
            if (conf.is_forward && conf.reduce_dim % 16 == 0
                    && reduce_dim_stride == 1) {
                // Calculations over reduce dimension will be splitted
                // between work items in the single subgroup.
                // Each item will read vector_size of elements at once.
                conf.vectorize_calc_stats = true;
                conf.sub_group_size = 16;

                // Largest power-of-two width (<= 8) such that
                // sub_group_size * vector_size evenly tiles the reduce dim.
                int vector_size = 8;
                while (conf.reduce_dim % (conf.sub_group_size * vector_size)
                        != 0) {
                    vector_size /= 2;
                }
                conf.vect_size = vector_size;
                calc_dims_blocks[reduce_dim_idx]
                        = conf.reduce_dim / conf.sub_group_size;
            } else {
                // Whole reduce dimension will be handled by single work item.
                calc_dims[reduce_dim_idx] = 1;
            }

            // Number of partial-statistic slots the calc kernel writes; this
            // also sizes the reduction scratchpad (see init_scratchpad()).
            conf.stat_ic = utils::array_product(calc_dims, 5);
            conf.dispatch_calc_stat.define_dim(
                    dim_names[0], 0, calc_dims[0], calc_dims_blocks[0]);
            conf.dispatch_calc_stat.define_dim(
                    dim_names[1], 1, calc_dims[1], calc_dims_blocks[1]);
            conf.dispatch_calc_stat.define_dim(dim_names[2],
                    nstl::max(1, ndims - 3), calc_dims[2], calc_dims_blocks[2]);
            conf.dispatch_calc_stat.define_dim(dim_names[3],
                    nstl::max(1, ndims - 2), calc_dims[3], calc_dims_blocks[3]);
            conf.dispatch_calc_stat.define_dim(dim_names[4],
                    nstl::max(1, ndims - 1), calc_dims[4], calc_dims_blocks[4]);

            conf.skip_reduce_stat = false;
            if (conf.vectorize_calc_stats) {
                CHECK(conf.dispatch_calc_stat.vectorize_dim(
                        reduce_dim_name, conf.sub_group_size));
                if (conf.stat_ic == conf.reduce_dim * calc_dims[1]) {
                    // if there are only 2 dimensions greater than 1:
                    // IC and reduce_dim, calc phase of batchnorm will do
                    // whole reduction and reduce phase can be skipped
                    conf.skip_reduce_stat = true;
                }
            }

            conf.dispatch_calc_stat.set_kernel_attr_suffix("CALC");
            conf.dispatch_calc_stat.generate();

            conf.dispatch_reduce_stat = compute_engine->create_dispatch();
            conf.dispatch_reduce_stat.define_dim("REDUCE_STAT_IC", conf.ic);
            conf.dispatch_reduce_stat.set_kernel_attr_suffix("REDUCE");
            conf.dispatch_reduce_stat.generate();
        }
    }

    return status::success;
}
224 | |
225 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
226 | const bnorm_conf_t &conf, const offsets_t &off) { |
227 | kernel_ctx.set_data_type(conf.data_type); |
228 | |
229 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
230 | kernel_ctx.define_int("MB" , conf.mb); |
231 | kernel_ctx.define_int("IC" , conf.ic); |
232 | kernel_ctx.define_int("ID" , conf.id); |
233 | kernel_ctx.define_int("IH" , conf.ih); |
234 | kernel_ctx.define_int("IW" , conf.iw); |
235 | kernel_ctx.define_int("USE_16MB_UNROLL" , conf.use_16mb_unroll); |
236 | kernel_ctx.define_int("USE_NHWC" , conf.use_nhwc); |
237 | kernel_ctx.define_int("REDUCE_STAT_NBLOCKS" , conf.reduce_stat_nblocks); |
238 | kernel_ctx.define_int("MB_BLOCK" , conf.mb_block); |
239 | kernel_ctx.define_int("IC_BLOCK" , conf.ic_block); |
240 | |
241 | kernel_ctx.define_int("REDUCE_DIM_IDX" , conf.reduce_dim_idx); |
242 | kernel_ctx.define_int("REDUCE_DIM" , conf.reduce_dim); |
243 | |
244 | if (conf.is_forward) |
245 | kernel_ctx.define_int("IS_FWD" , 1); |
246 | else if (conf.is_backward) |
247 | kernel_ctx.define_int("IS_BWD" , 1); |
248 | |
249 | kernel_ctx.define_int("WITH_RELU" , conf.with_relu); |
250 | kernel_ctx.define_int("SAVE_STATS" , conf.save_stats); |
251 | kernel_ctx.define_int("IS_TRAINING" , conf.is_training); |
252 | kernel_ctx.define_int("FUSE_BN_RELU" , conf.fuse_norm_relu); |
253 | kernel_ctx.define_int("FUSE_BN_ADD_RELU" , conf.fuse_norm_add_relu); |
254 | kernel_ctx.define_int("CALCULATE_STATS" , conf.calculate_stats); |
255 | kernel_ctx.define_int("USE_SCALE" , conf.use_scale); |
256 | kernel_ctx.define_int("USE_SHIFT" , conf.use_shift); |
257 | kernel_ctx.define_int("CALCULATE_DIFF_STATS" , conf.calculate_diff_stats); |
258 | kernel_ctx.define_int("DIFF_SCALE" , conf.diff_scale); |
259 | kernel_ctx.define_int("DIFF_SHIFT" , conf.diff_shift); |
260 | kernel_ctx.define_int("VECTORIZE_CALC_STATS" , conf.vectorize_calc_stats); |
261 | kernel_ctx.define_int("SUB_GROUP_SIZE" , conf.sub_group_size); |
262 | kernel_ctx.define_int("VECT_SIZE" , conf.vect_size); |
263 | kernel_ctx.define_int("SKIP_REDUCE_STATS" , conf.skip_reduce_stat); |
264 | |
265 | def_offsets(off.src_off, kernel_ctx, "SRC" , conf.ndims); |
266 | |
267 | if (conf.data_type == data_type::s8) |
268 | kernel_ctx.add_option("-Dcl_intel_subgroups_char" ); |
269 | |
270 | if (conf.calculate_stats || conf.is_backward) { |
271 | def_dispatch(kernel_ctx, conf.dispatch_calc_stat); |
272 | def_dispatch(kernel_ctx, conf.dispatch_reduce_stat); |
273 | } |
274 | def_dispatch(kernel_ctx, conf.dispatch); |
275 | |
276 | return status::success; |
277 | } |
278 | |
279 | status_t ref_batch_normalization_fwd_t::pd_t::init_conf(engine_t *engine) { |
280 | return init_conf_common(conf, off, this, engine); |
281 | } |
282 | |
283 | status_t ref_batch_normalization_fwd_t::pd_t::init_kernel_ctx( |
284 | compute::kernel_ctx_t &kernel_ctx) const { |
285 | return init_kernel_ctx_common(kernel_ctx, conf, off); |
286 | } |
287 | |
288 | void ref_batch_normalization_fwd_t::pd_t::init_scratchpad() { |
289 | if (conf.calculate_stats) { |
290 | |
291 | size_t size = 2 * conf.stat_ic; |
292 | |
293 | auto scratchpad = scratchpad_registry().registrar(); |
294 | scratchpad.book(memory_tracking::names::key_bnorm_reduction, size, |
295 | types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT); |
296 | } |
297 | } |
298 | |
// Runs forward batch normalization. When statistics are computed on device,
// either a single fused mean+variance kernel is launched (skip_reduce_stat)
// or a four-kernel calc/reduce sequence per statistic; the main kernel then
// normalizes (and optionally fuses ReLU/add, per the build-time macros).
status_t ref_batch_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    // Second input of the fused "add" operation (fuse_norm_add_relu).
    auto &src_add = CTX_IN_STORAGE(DNNL_ARG_SRC_1);

    // mean/variance are inputs when supplied by the user (stats_is_src),
    // otherwise output buffers this primitive fills.
    auto &mean_ = pd()->stats_is_src()
            ? CTX_IN_STORAGE(DNNL_ARG_MEAN)
            : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_MEAN, status);
    CHECK(status);

    auto &variance_ = pd()->stats_is_src()
            ? CTX_IN_STORAGE(DNNL_ARG_VARIANCE)
            : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_VARIANCE, status);
    CHECK(status);

    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &shift = CTX_IN_STORAGE(DNNL_ARG_SHIFT);

    auto &dst = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status);
    CHECK(status);
    auto &ws = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_WORKSPACE, status);
    CHECK(status);

    auto *mean_ptr = &mean_;
    auto *variance_ptr = &variance_;

    std::unique_ptr<memory_storage_t> temp_reduce = nullptr;
    if (conf.calculate_stats) {
        // Scratchpad holds partial stats for the reduce phase and/or serves
        // as a throwaway destination when stats need not be saved.
        if (!conf.skip_reduce_stat || !conf.save_stats) {
            temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
                    key_bnorm_reduction);
        }

        // When stats are not saved, redirect both mean and variance into the
        // scratchpad so no user-visible buffer is written.
        if (!conf.save_stats) {
            mean_ptr = temp_reduce.get();
            variance_ptr = temp_reduce.get();
        }
    }

    auto &mean = *mean_ptr;
    auto &variance = *variance_ptr;

    if (conf.calculate_stats) {
        if (conf.skip_reduce_stat) {
            // Fused path: one kernel computes mean and variance directly.
            compute::kernel_arg_list_t calc_var_arg_list;
            calc_var_arg_list.set(0, src);
            calc_var_arg_list.set(1, mean);
            calc_var_arg_list.set(2, variance);

            auto nd_range_calc_var = conf.dispatch_calc_stat.nd_range();

            status = parallel_for(ctx, nd_range_calc_var,
                    calculate_mean_variance_kernel_, calc_var_arg_list);
            if (status != status::success) return status;
        } else {
            // Four-kernel path: partial mean -> reduce mean -> partial
            // variance (uses reduced mean) -> reduce variance.
            compute::kernel_arg_list_t calc_mean_arg_list;
            calc_mean_arg_list.set(0, src);
            calc_mean_arg_list.set(1, *temp_reduce);

            auto nd_range_calc_mean = conf.dispatch_calc_stat.nd_range();

            status = parallel_for(ctx, nd_range_calc_mean,
                    calculate_mean_kernel_, calc_mean_arg_list);
            if (status != status::success) return status;

            compute::kernel_arg_list_t reduce_mean_arg_list;
            reduce_mean_arg_list.set(0, *temp_reduce);
            reduce_mean_arg_list.set(1, mean);

            auto nd_range_reduce_mean = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_mean,
                    reduce_mean_kernel_, reduce_mean_arg_list);
            if (status != status::success) return status;

            compute::kernel_arg_list_t calc_var_arg_list;
            calc_var_arg_list.set(0, src);
            calc_var_arg_list.set(1, mean);
            calc_var_arg_list.set(2, *temp_reduce);

            auto nd_range_calc_var = conf.dispatch_calc_stat.nd_range();

            status = parallel_for(ctx, nd_range_calc_var,
                    calculate_variance_kernel_, calc_var_arg_list);
            if (status != status::success) return status;

            compute::kernel_arg_list_t reduce_var_arg_list;
            reduce_var_arg_list.set(0, *temp_reduce);
            reduce_var_arg_list.set(1, variance);

            auto nd_range_reduce_var = conf.dispatch_reduce_stat.nd_range();

            status = parallel_for(ctx, nd_range_reduce_var,
                    reduce_variance_kernel_, reduce_var_arg_list);
            if (status != status::success) return status;
        }
    }

    // Main normalization kernel; argument order must match the OpenCL kernel
    // signature.
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, dst);
    arg_list.set(4, scale);
    arg_list.set(5, shift);
    arg_list.set(6, ws);
    arg_list.set(7, conf.eps);
    arg_list.set(8, src_add);

    auto nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}
418 | |
419 | status_t ref_batch_normalization_bwd_t::pd_t::init_conf(engine_t *engine) { |
420 | return init_conf_common(conf, off, this, engine); |
421 | } |
422 | |
423 | status_t ref_batch_normalization_bwd_t::pd_t::init_kernel_ctx( |
424 | compute::kernel_ctx_t &kernel_ctx) const { |
425 | return init_kernel_ctx_common(kernel_ctx, conf, off); |
426 | } |
427 | |
428 | void ref_batch_normalization_bwd_t::pd_t::init_scratchpad() { |
429 | size_t size; |
430 | if (conf.use_16mb_unroll) { |
431 | size = 2 * conf.reduce_stat_nblocks * conf.ic; |
432 | } else { |
433 | size = 2 * conf.stat_ic; |
434 | } |
435 | |
436 | auto scratchpad = scratchpad_registry().registrar(); |
437 | scratchpad.book(memory_tracking::names::key_bnorm_reduction, size, |
438 | types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT); |
439 | } |
440 | |
// Runs backward batch normalization in three kernel launches: per-chunk
// diff-statistic calculation, reduction of those partials into diff_scale /
// diff_shift, then the main diff_src computation.
status_t ref_batch_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {

    status_t status = status::success;
    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &mean = CTX_IN_STORAGE(DNNL_ARG_MEAN);
    auto &variance = CTX_IN_STORAGE(DNNL_ARG_VARIANCE);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE);

    auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
    CHECK(status);
    // Gradient of the fused "add" input (fuse_norm_add_relu).
    auto &diff_src_add = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC_1, status);
    CHECK(status);
    auto &diff_scale_ = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SCALE, status);
    CHECK(status);
    auto &diff_shift_ = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SHIFT, status);
    CHECK(status);

    // Scratchpad for the partial diff statistics produced by the calc kernel
    // (sized in init_scratchpad()).
    std::unique_ptr<memory_storage_t> temp_reduce;
    temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
            key_bnorm_reduction);

    // Phase 1: per-chunk partial diff statistics.
    compute::kernel_arg_list_t calc_stats_arg_list;
    calc_stats_arg_list.set(0, src);
    calc_stats_arg_list.set(1, mean);
    calc_stats_arg_list.set(2, diff_dst);
    calc_stats_arg_list.set(3, ws);
    calc_stats_arg_list.set(4, *temp_reduce);

    auto nd_range = conf.dispatch_calc_stat.nd_range();

    status = parallel_for(
            ctx, nd_range, calculate_stats_kernel_, calc_stats_arg_list);
    if (status != status::success) return status;

    // When diff_scale/diff_shift are not requested, dump those results into
    // the scratchpad instead of a user buffer.
    auto &diff_scale = !conf.diff_scale ? *temp_reduce : diff_scale_;
    auto &diff_shift = !conf.diff_shift ? *temp_reduce : diff_shift_;

    // Phase 2: reduce the partials into per-channel diff_scale/diff_shift.
    compute::kernel_arg_list_t reduce_stats_arg_list;
    reduce_stats_arg_list.set(0, *temp_reduce);
    reduce_stats_arg_list.set(1, diff_scale);
    reduce_stats_arg_list.set(2, diff_shift);
    reduce_stats_arg_list.set(3, variance);
    reduce_stats_arg_list.set(4, conf.eps);

    auto nd_range_reduce_stat = conf.dispatch_reduce_stat.nd_range();

    status = parallel_for(ctx, nd_range_reduce_stat, reduce_stats_kernel_,
            reduce_stats_arg_list);
    if (status != status::success) return status;

    // Phase 3: main kernel computes diff_src (and diff_src_add when fused);
    // argument order must match the OpenCL kernel signature.
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, diff_dst);
    arg_list.set(4, scale);
    arg_list.set(5, ws);
    arg_list.set(6, diff_src);
    arg_list.set(7, diff_scale);
    arg_list.set(8, diff_shift);
    arg_list.set(9, conf.eps);
    arg_list.set(10, diff_src_add);

    nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}
515 | |
516 | } // namespace ocl |
517 | } // namespace gpu |
518 | } // namespace impl |
519 | } // namespace dnnl |
520 | |
521 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
522 | |