/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ref_layer_normalization.hpp"
#include "common/c_types_map.hpp"

#include "common/primitive_exec_types.hpp"
#include "common/scratchpad.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

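// Fills the lnorm_conf_t descriptor shared by the forward and backward
// primitives: memory descriptor info, the normalization axis, and the
// vectorization/dispatch parameters derived from the source memory layout.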
static status_t init_conf_common(lnorm_conf_t &conf,
        const layer_normalization_pd_t *pd, engine_t *engine) {
    using namespace dnnl::impl::format_tag;

    memory_desc_wrapper src_mdw(pd->src_md());
    memory_desc_wrapper stat_mdw(pd->stat_md());
    memory_desc_wrapper dst_mdw(
            pd->is_fwd() ? pd->dst_md() : pd->diff_dst_md());

    int ndims = src_mdw.ndims();

    conf.data_type = src_mdw.data_type();
    conf.ndims = ndims;
    conf.norm_axis = pd->norm_axis();

    conf.src_md_info = memory_desc_info_t::create(src_mdw);
    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);
    conf.stat_md_info = memory_desc_info_t::create(stat_mdw);

    conf.is_fwd = pd->is_fwd();

    conf.vectorize_calc_stats = false;
    conf.vect_dt_n = 1;
    conf.sub_group_size = 1;

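    // Query the innermost block size and whether the last logical dimension
    // (the one being normalized over) is also the innermost physical
    // dimension; the vectorized paths below rely on this.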
    int c_block = 1;
    bool c_is_last_physical = false;
    if (src_mdw.blocking_desc().inner_nblks > 0) {
        c_block = src_mdw.blocking_desc()
                          .inner_blks[src_mdw.blocking_desc().inner_nblks - 1];
        c_is_last_physical
                = src_mdw.blocking_desc().inner_idxs[ndims - 1] == ndims - 1;
    } else {
        c_is_last_physical = src_mdw.blocking_desc().strides[ndims - 1] == 1;
    }

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    conf.dispatch_scaleshift = compute_engine->create_dispatch();
    conf.dispatch_scaleshift_finalize = compute_engine->create_dispatch();
    conf.dispatch = compute_engine->create_dispatch(
            pd->is_fwd() ? dst_mdw.md_ : src_mdw.md_);
    const auto &dims = pd->is_fwd() ? src_mdw.padded_dims() : dst_mdw.dims();
    if (pd->is_fwd()) {
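        // Vectorize the mean/variance computation over the normalization axis
        // when it is contiguous and a multiple of the 16-wide sub-group.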
        if ((conf.norm_axis % 16 == 0) && ndims < 4 && c_is_last_physical
                && (c_block == 1 || (c_block % 16 == 0 && ndims == 2)
                        || src_mdw.is_dense())) {
            conf.vectorize_calc_stats = true;
            conf.sub_group_size = 16;
            int vector_size = 8;
            while (conf.norm_axis % (conf.sub_group_size * vector_size) != 0) {
                vector_size /= 2;
            }
            while (c_block > 1 && vector_size * conf.sub_group_size > c_block) {
                vector_size /= 2;
            }
            conf.vect_dt_n = vector_size;
        }
        for (int i = 0; i < 4; i++) {
            int md_hint_idx = nstl::min(i, ndims - 1);
            int dim = (i < ndims - 1) ? dims[i] : 1;
            if (conf.vectorize_calc_stats && (i == ndims - 1)) {
                dim = 16;
                conf.dispatch.define_dim(
                        utils::format("X%d", i), md_hint_idx, dim);
                CHECK(conf.dispatch.vectorize_dim(
                        utils::format("X%d", i), conf.sub_group_size));
            } else
                conf.dispatch.define_dim(
                        utils::format("X%d", i), md_hint_idx, dim);
        }
    } else {
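        // Backward pass: vectorize over the normalization axis when it is a
        // multiple of the sub-group size and the layout is plain
        // (ab, abc, abcd, abcde) or a channel-blocked 2D case.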
        conf.vectorize_bwd = false;
        const int desired_sg_size = 16;
        if (conf.norm_axis % desired_sg_size == 0
                && (src_mdw.matches_one_of_tag(ab, abc, abcd, abcde)
                        || (ndims == 2 && c_block % desired_sg_size == 0
                                && c_is_last_physical))) {
            conf.vectorize_bwd = true;
            conf.sub_group_size = desired_sg_size;
            conf.vect_dt_n = 8;
            while (conf.norm_axis % (conf.sub_group_size * conf.vect_dt_n)
                    != 0) {
                conf.vect_dt_n /= 2;
            }
            while (src_mdw.blocking_desc().inner_nblks > 0
                    && c_block % (conf.sub_group_size * conf.vect_dt_n) != 0) {
                conf.vect_dt_n /= 2;
            }
        }
        for (int i = 0; i < 4; i++) {
            int md_hint_idx = nstl::min(i, ndims - 1);
            int dim = (i < ndims - 1) ? dims[i] : 1;
            if (conf.vectorize_bwd && (i == ndims - 1)) {
                conf.dispatch.define_dim(utils::format("X%d", i), md_hint_idx,
                        conf.sub_group_size);
                CHECK(conf.dispatch.vectorize_dim(
                        utils::format("X%d", i), conf.sub_group_size));
            } else {
                conf.dispatch.define_dim(
                        utils::format("X%d", i), md_hint_idx, dim);
            }
        }

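        // Configure the diff_scale/diff_shift reduction. When vectorized, the
        // reduction over N is split into chunks handled by the SCALESHIFT
        // kernel and combined by the SCALESHIFT_FINALIZE kernel.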
        int n_block = 1;
        conf.n_chunk_size = 1;
        conf.vector_size_scaleshift = 1;
        conf.n_chunks = dims[0] / conf.n_chunk_size;
        if (src_mdw.blocking_desc().inner_nblks == 2
                && src_mdw.blocking_desc().inner_idxs[0] == 0) {
            n_block = src_mdw.blocking_desc().inner_blks[0];
        }
        // Scaleshift vectorization is supported for tensors
        // with shapes AxB, 1xBxC
        conf.vectorize_bwd_scaleshift = conf.vectorize_bwd
                && stat_mdw.matches_one_of_tag(a, ab)
                && ((ndims == 2
                            && (c_block == desired_sg_size
                                    || src_mdw.matches_tag(ab)))
                        || (ndims == 3 && src_mdw.matches_tag(abc)
                                && dims[0] == 1));
        if (conf.vectorize_bwd_scaleshift) {
            // Use partial reduction in order to increase the number of
            // active threads
            conf.vector_size_scaleshift = c_block == desired_sg_size ? 8 : 1;
            const int first_dim = ndims == 2 ? dims[0] : dims[1];
            while (n_block % conf.vector_size_scaleshift != 0
                    || first_dim % conf.vector_size_scaleshift != 0) {
                conf.vector_size_scaleshift /= 2;
            }
            // Experimentally selected values
            const int max_first_dim_elems_per_wi = 32;
            int desired_first_dim_block_reads
                    = max_first_dim_elems_per_wi / conf.vector_size_scaleshift;
            while (first_dim
                    % (desired_first_dim_block_reads
                            * conf.vector_size_scaleshift)
                    != 0) {
                desired_first_dim_block_reads /= 2;
            }
            while (first_dim
                    % (desired_first_dim_block_reads
                            * conf.vector_size_scaleshift)
                    != 0) {
                conf.vector_size_scaleshift /= 2;
            }
            conf.n_chunk_size = desired_first_dim_block_reads
                    * conf.vector_size_scaleshift;
            conf.n_chunks = first_dim / conf.n_chunk_size;
            // The scaleshift kernel does a partial reduction over N
            conf.dispatch_scaleshift.define_dim("N", conf.n_chunks);
            conf.dispatch_scaleshift.define_dim("C", pd->norm_axis());
            CHECK(conf.dispatch_scaleshift.vectorize_dim(
                    "C", conf.sub_group_size));
            conf.dispatch_scaleshift.set_kernel_attr_suffix("SCALESHIFT");
            conf.dispatch_scaleshift.generate();
            // The finalize kernel reduces the partial results of the
            // scaleshift kernel
            conf.dispatch_scaleshift_finalize.define_dim(
                    "C_finalize", pd->norm_axis());
            conf.dispatch_scaleshift_finalize.set_kernel_attr_suffix(
                    "SCALESHIFT_FINALIZE");
            conf.dispatch_scaleshift_finalize.generate();
        } else {
            conf.dispatch_scaleshift.define_dim("C", pd->norm_axis());
            conf.dispatch_scaleshift.set_kernel_attr_suffix("SCALESHIFT");
            conf.dispatch_scaleshift.generate();
        }
    }

    conf.dispatch.generate();

    conf.use_scale = pd->use_scale();
    conf.use_shift = pd->use_shift();

    conf.calculate_stats = !pd->stats_are_src();
    conf.save_stats = pd->is_training();
    conf.eps = pd->desc()->layer_norm_epsilon;

    return status::success;
}

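// Translates the configuration into build-time macros for the OpenCL kernels.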
static status_t init_kernel_ctx_common(
        compute::kernel_ctx_t &kernel_ctx, const lnorm_conf_t &conf) {
    kernel_ctx.set_data_type(conf.data_type);

    kernel_ctx.define_int("C", conf.norm_axis);
    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("USE_SCALE", conf.use_scale);
    kernel_ctx.define_int("USE_SHIFT", conf.use_shift);
    kernel_ctx.define_int("CALCULATE_STATS", conf.calculate_stats);
    kernel_ctx.define_int("SAVE_STATS", conf.save_stats);
    kernel_ctx.define_int("IS_FWD", conf.is_fwd);
    kernel_ctx.define_int("IS_BWD", !conf.is_fwd);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
    kernel_ctx.define_int("VECTORIZE_CALC_STATS", conf.vectorize_calc_stats);
    kernel_ctx.define_int("VECTORIZE_BWD", conf.vectorize_bwd);
    kernel_ctx.define_int(
            "VECTORIZE_BWD_SCALESHIFT", conf.vectorize_bwd_scaleshift);
    kernel_ctx.define_int("VECT_DT_N", conf.vect_dt_n);
    kernel_ctx.define_int(
            "VECTOR_SIZE_SCALESHIFT", conf.vector_size_scaleshift);
    kernel_ctx.define_int("N_CHUNK_SIZE", conf.n_chunk_size);
    kernel_ctx.define_int("N_CHUNKS", conf.n_chunks);

    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
    def_memory_desc_info(kernel_ctx, conf.stat_md_info, "STAT");

    def_dispatch(kernel_ctx, conf.dispatch);
    if (!conf.is_fwd) {
        def_dispatch(kernel_ctx, conf.dispatch_scaleshift);
        if (conf.vectorize_bwd_scaleshift)
            def_dispatch(kernel_ctx, conf.dispatch_scaleshift_finalize);
    }

    return status::success;
}

status_t ref_layer_normalization_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}

status_t ref_layer_normalization_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}

status_t ref_layer_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    if (pd()->has_zero_dim_memory()) return status::success;

    const auto &conf = pd()->conf;
    status_t status = status::success;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &mean = pd()->stats_are_src() ? CTX_IN_STORAGE(DNNL_ARG_MEAN)
                                       : CTX_OUT_STORAGE(DNNL_ARG_MEAN);

    auto &variance = pd()->stats_are_src() ? CTX_IN_STORAGE(DNNL_ARG_VARIANCE)
                                           : CTX_OUT_STORAGE(DNNL_ARG_VARIANCE);

    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &shift = CTX_IN_STORAGE(DNNL_ARG_SHIFT);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

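    // Bind the arguments in the order expected by the forward OpenCL kernel.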
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, dst);
    arg_list.set(4, scale);
    arg_list.set(5, shift);
    arg_list.set(6, conf.eps);

    auto nd_range_kernel = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range_kernel, kernel_, arg_list);

    return status;
}

status_t ref_layer_normalization_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}

status_t ref_layer_normalization_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}

void ref_layer_normalization_bwd_t::pd_t::init_scratchpad() {
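    // Reserve an f32 buffer for the partial diff_scale and diff_shift sums:
    // n_chunks x norm_axis elements for each of the two tensors.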
    const size_t size = conf.n_chunks * conf.norm_axis * 2;
    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(memory_tracking::names::key_lnorm_reduction, size,
            types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
}

status_t ref_layer_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {
    if (pd()->has_zero_dim_memory()) return status::success;

    status_t status = status::success;

    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &mean = CTX_IN_STORAGE(DNNL_ARG_MEAN);
    auto &variance = CTX_IN_STORAGE(DNNL_ARG_VARIANCE);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);

    auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
    CHECK(status);
    auto &diff_scale = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SCALE);
    auto &diff_shift = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SHIFT);

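    // Compute diff_scale/diff_shift first. In the vectorized case the
    // scaleshift kernel writes partial sums to the scratchpad, which the
    // finalize kernel then reduces into the user's buffers.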
    if (conf.use_scale || conf.use_shift) {
        std::unique_ptr<memory_storage_t> temp_reduce;
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, src);
        arg_list.set(1, mean);
        arg_list.set(2, variance);
        arg_list.set(3, diff_dst);
        if (conf.vectorize_bwd_scaleshift) {
            temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
                    memory_tracking::names::key_lnorm_reduction);
            arg_list.set(4, *temp_reduce);
            arg_list.set(5, *temp_reduce);
        } else {
            arg_list.set(4, diff_scale);
            arg_list.set(5, diff_shift);
        }
        arg_list.set(6, conf.eps);

        auto nd_range = conf.dispatch_scaleshift.nd_range();
        status = parallel_for(ctx, nd_range, kernel_scaleshift_, arg_list);
        if (status != status::success) return status;

        if (conf.vectorize_bwd_scaleshift) {
            compute::kernel_arg_list_t arg_list_final;
            arg_list_final.set(0, *temp_reduce);
            arg_list_final.set(1, diff_scale);
            arg_list_final.set(2, diff_shift);

            auto nd_range_finalize
                    = conf.dispatch_scaleshift_finalize.nd_range();
            status = parallel_for(ctx, nd_range_finalize,
                    kernel_scaleshift_finalize_, arg_list_final);
            if (status != status::success) return status;
        }
    }

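    // Main backward kernel: computes diff_src from src, the saved statistics,
    // and diff_dst.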
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, diff_dst);
    arg_list.set(4, scale);
    arg_list.set(5, diff_src);
    arg_list.set(6, conf.eps);

    auto nd_range_kernel = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range_kernel, kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl