/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ref_layer_normalization.hpp"
#include "common/c_types_map.hpp"

#include "common/primitive_exec_types.hpp"
#include "common/scratchpad.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

static status_t init_conf_common(lnorm_conf_t &conf,
        const layer_normalization_pd_t *pd, engine_t *engine) {
    using namespace dnnl::impl::format_tag;

    memory_desc_wrapper src_mdw(pd->src_md());
    memory_desc_wrapper stat_mdw(pd->stat_md());
    memory_desc_wrapper dst_mdw(
            pd->is_fwd() ? pd->dst_md() : pd->diff_dst_md());

    int ndims = src_mdw.ndims();

    conf.data_type = src_mdw.data_type();
    conf.ndims = ndims;
    conf.norm_axis = pd->norm_axis();

    conf.src_md_info = memory_desc_info_t::create(src_mdw);
    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);
    conf.stat_md_info = memory_desc_info_t::create(stat_mdw);

    conf.is_fwd = pd->is_fwd();

    conf.vectorize_calc_stats = false;
    conf.vect_dt_n = 1;
    conf.sub_group_size = 1;

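    // Determine the innermost blocking over the channel axis: c_block is the
    // innermost physical block size, and c_is_last_physical tells whether
    // the last logical dim is also physically innermost. Both gate the
    // vectorized paths below.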
    int c_block = 1;
    bool c_is_last_physical = false;
    if (src_mdw.blocking_desc().inner_nblks > 0) {
        c_block = src_mdw.blocking_desc()
                          .inner_blks[src_mdw.blocking_desc().inner_nblks - 1];
        c_is_last_physical
                = src_mdw.blocking_desc().inner_idxs[ndims - 1] == ndims - 1;
    } else {
        c_is_last_physical = src_mdw.blocking_desc().strides[ndims - 1] == 1;
    }

    auto *compute_engine
            = utils::downcast<compute::compute_engine_t *>(engine);
    conf.dispatch_scaleshift = compute_engine->create_dispatch();
    conf.dispatch_scaleshift_finalize = compute_engine->create_dispatch();
    conf.dispatch = compute_engine->create_dispatch(
            pd->is_fwd() ? dst_mdw.md_ : src_mdw.md_);
    const auto &dims = pd->is_fwd() ? src_mdw.padded_dims() : dst_mdw.dims();
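    // Note the asymmetry above: fwd dispatches over the padded dims of src
    // (so blocked layouts are fully covered), while bwd uses logical dims.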
    if (pd->is_fwd()) {
        if ((conf.norm_axis % 16 == 0) && ndims < 4 && c_is_last_physical
                && (c_block == 1 || (c_block % 16 == 0 && ndims == 2)
                        || src_mdw.is_dense())) {
            conf.vectorize_calc_stats = true;
            conf.sub_group_size = 16;
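            // Pick the widest vector size (up to 8) that evenly divides the
            // norm axis per subgroup and fits the channel block. E.g., with
            // illustrative values: norm_axis = 1024 keeps vector_size = 8
            // (1024 % 128 == 0), while norm_axis = 80 ends up with
            // vector_size = 1 (only 80 % 16 == 0).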
            int vector_size = 8;
            while (conf.norm_axis % (conf.sub_group_size * vector_size) != 0) {
                vector_size /= 2;
            }
            while (c_block > 1 && vector_size * conf.sub_group_size > c_block) {
                vector_size /= 2;
            }
            conf.vect_dt_n = vector_size;
        }
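        // The dispatch is defined over four logical dims X0..X3. The
        // normalized (last) axis is reduced inside the kernel, so it is
        // dispatched with size 1, except in the vectorized case, where it is
        // dispatched as one 16-wide subgroup.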
        for (int i = 0; i < 4; i++) {
            int md_hint_idx = nstl::min(i, ndims - 1);
            int dim = (i < ndims - 1) ? dims[i] : 1;
            if (conf.vectorize_calc_stats && (i == ndims - 1)) {
                dim = 16;
                conf.dispatch.define_dim(
                        utils::format("X%d", i), md_hint_idx, dim);
                CHECK(conf.dispatch.vectorize_dim(
                        utils::format("X%d", i), conf.sub_group_size));
            } else
                conf.dispatch.define_dim(
                        utils::format("X%d", i), md_hint_idx, dim);
        }
    } else {
        conf.vectorize_bwd = false;
        const int desired_sg_size = 16;
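        // Vectorize bwd only when the norm axis is a multiple of the
        // subgroup size and the layout is either plain dense (ab, abc, ...)
        // or 2D blocked with a channel block that is a multiple of 16 and
        // physically innermost.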
        if (conf.norm_axis % desired_sg_size == 0
                && (src_mdw.matches_one_of_tag(ab, abc, abcd, abcde)
                        || (ndims == 2 && c_block % desired_sg_size == 0
                                && c_is_last_physical))) {
            conf.vectorize_bwd = true;
            conf.sub_group_size = desired_sg_size;
            conf.vect_dt_n = 8;
            while (conf.norm_axis % (conf.sub_group_size * conf.vect_dt_n)
                    != 0) {
                conf.vect_dt_n /= 2;
            }
            while (src_mdw.blocking_desc().inner_nblks > 0
                    && c_block % (conf.sub_group_size * conf.vect_dt_n) != 0) {
                conf.vect_dt_n /= 2;
            }
        }
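        // Same four-dim dispatch scheme as the fwd path; when vectorized,
        // the last dim is dispatched as a single subgroup.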
        for (int i = 0; i < 4; i++) {
            int md_hint_idx = nstl::min(i, ndims - 1);
            int dim = (i < ndims - 1) ? dims[i] : 1;
            if (conf.vectorize_bwd && (i == ndims - 1)) {
                conf.dispatch.define_dim(utils::format("X%d", i), md_hint_idx,
                        conf.sub_group_size);
                CHECK(conf.dispatch.vectorize_dim(
                        utils::format("X%d", i), conf.sub_group_size));
            } else {
                conf.dispatch.define_dim(
                        utils::format("X%d", i), md_hint_idx, dim);
            }
        }

        int n_block = 1;
        conf.n_chunk_size = 1;
        conf.vector_size_scaleshift = 1;
        conf.n_chunks = dims[0] / conf.n_chunk_size;
        if (src_mdw.blocking_desc().inner_nblks == 2
                && src_mdw.blocking_desc().inner_idxs[0] == 0) {
            n_block = src_mdw.blocking_desc().inner_blks[0];
        }
        // Scaleshift vectorization is supported for tensors with shapes
        // AxB and 1xBxC.
        conf.vectorize_bwd_scaleshift = conf.vectorize_bwd
                && stat_mdw.matches_one_of_tag(a, ab)
                && ((ndims == 2
                            && (c_block == desired_sg_size
                                    || src_mdw.matches_tag(ab)))
                        || (ndims == 3 && src_mdw.matches_tag(abc)
                                && dims[0] == 1));
        if (conf.vectorize_bwd_scaleshift) {
            // Use partial reduction to increase the number of active threads.
            conf.vector_size_scaleshift = c_block == desired_sg_size ? 8 : 1;
            const int first_dim = ndims == 2 ? dims[0] : dims[1];
            while (n_block % conf.vector_size_scaleshift != 0
                    || first_dim % conf.vector_size_scaleshift != 0) {
                conf.vector_size_scaleshift /= 2;
            }
            // Experimentally selected values.
            const int max_first_dim_elems_per_wi = 32;
            int desired_first_dim_block_reads
                    = max_first_dim_elems_per_wi / conf.vector_size_scaleshift;
            while (first_dim
                            % (desired_first_dim_block_reads
                                    * conf.vector_size_scaleshift)
                    != 0) {
                desired_first_dim_block_reads /= 2;
            }
            while (first_dim
                            % (desired_first_dim_block_reads
                                    * conf.vector_size_scaleshift)
                    != 0) {
                conf.vector_size_scaleshift /= 2;
            }
            conf.n_chunk_size = desired_first_dim_block_reads
                    * conf.vector_size_scaleshift;
            conf.n_chunks = first_dim / conf.n_chunk_size;
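            // E.g., with illustrative values: first_dim = 256 and a 16x16
            // inner blocking (c_block == 16, n_block == 16) keep
            // vector_size_scaleshift = 8, so desired_first_dim_block_reads
            // = 32 / 8 = 4, n_chunk_size = 32, and n_chunks = 8.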
            // Scaleshift kernel does partial reduction of N.
            conf.dispatch_scaleshift.define_dim("N", conf.n_chunks);
            conf.dispatch_scaleshift.define_dim("C", pd->norm_axis());
            CHECK(conf.dispatch_scaleshift.vectorize_dim(
                    "C", conf.sub_group_size));
            conf.dispatch_scaleshift.set_kernel_attr_suffix("SCALESHIFT");
            conf.dispatch_scaleshift.generate();
            // Scaleshift finalize kernel reduces the partial results of the
            // scaleshift kernel.
            conf.dispatch_scaleshift_finalize.define_dim(
                    "C_finalize", pd->norm_axis());
            conf.dispatch_scaleshift_finalize.set_kernel_attr_suffix(
                    "SCALESHIFT_FINALIZE");
            conf.dispatch_scaleshift_finalize.generate();
        } else {
            conf.dispatch_scaleshift.define_dim("C", pd->norm_axis());
            conf.dispatch_scaleshift.set_kernel_attr_suffix("SCALESHIFT");
            conf.dispatch_scaleshift.generate();
        }
    }

    conf.dispatch.generate();

    conf.use_scale = pd->use_scale();
    conf.use_shift = pd->use_shift();

    conf.calculate_stats = !pd->stats_are_src();
    conf.save_stats = pd->is_training();
    conf.eps = pd->desc()->layer_norm_epsilon;

    return status::success;
}

static status_t init_kernel_ctx_common(
        compute::kernel_ctx_t &kernel_ctx, const lnorm_conf_t &conf) {
    kernel_ctx.set_data_type(conf.data_type);

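    // These macros parameterize the corresponding OpenCL kernels; e.g.,
    // VECT_DT_N is expected to select the vector width the kernels use for
    // block reads and writes.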
    kernel_ctx.define_int("C", conf.norm_axis);
    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("USE_SCALE", conf.use_scale);
    kernel_ctx.define_int("USE_SHIFT", conf.use_shift);
    kernel_ctx.define_int("CALCULATE_STATS", conf.calculate_stats);
    kernel_ctx.define_int("SAVE_STATS", conf.save_stats);
    kernel_ctx.define_int("IS_FWD", conf.is_fwd);
    kernel_ctx.define_int("IS_BWD", !conf.is_fwd);
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
    kernel_ctx.define_int("VECTORIZE_CALC_STATS", conf.vectorize_calc_stats);
    kernel_ctx.define_int("VECTORIZE_BWD", conf.vectorize_bwd);
    kernel_ctx.define_int(
            "VECTORIZE_BWD_SCALESHIFT", conf.vectorize_bwd_scaleshift);
    kernel_ctx.define_int("VECT_DT_N", conf.vect_dt_n);
    kernel_ctx.define_int(
            "VECTOR_SIZE_SCALESHIFT", conf.vector_size_scaleshift);
    kernel_ctx.define_int("N_CHUNK_SIZE", conf.n_chunk_size);
    kernel_ctx.define_int("N_CHUNKS", conf.n_chunks);

    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
    def_memory_desc_info(kernel_ctx, conf.stat_md_info, "STAT");

    def_dispatch(kernel_ctx, conf.dispatch);
    if (!conf.is_fwd) {
        def_dispatch(kernel_ctx, conf.dispatch_scaleshift);
        if (conf.vectorize_bwd_scaleshift)
            def_dispatch(kernel_ctx, conf.dispatch_scaleshift_finalize);
    }

    return status::success;
}

status_t ref_layer_normalization_fwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}

status_t ref_layer_normalization_fwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}

status_t ref_layer_normalization_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    if (pd()->has_zero_dim_memory()) return status::success;

    const auto &conf = pd()->conf;
    status_t status = status::success;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &mean = pd()->stats_are_src() ? CTX_IN_STORAGE(DNNL_ARG_MEAN)
                                       : CTX_OUT_STORAGE(DNNL_ARG_MEAN);

    auto &variance = pd()->stats_are_src()
            ? CTX_IN_STORAGE(DNNL_ARG_VARIANCE)
            : CTX_OUT_STORAGE(DNNL_ARG_VARIANCE);

    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
    auto &shift = CTX_IN_STORAGE(DNNL_ARG_SHIFT);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

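    // Argument indices must match the parameter order of the OpenCL kernel.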
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, dst);
    arg_list.set(4, scale);
    arg_list.set(5, shift);
    arg_list.set(6, conf.eps);

    auto nd_range_kernel = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range_kernel, kernel_, arg_list);

    return status;
}

status_t ref_layer_normalization_bwd_t::pd_t::init_conf(engine_t *engine) {
    return init_conf_common(conf, this, engine);
}

status_t ref_layer_normalization_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf);
}

void ref_layer_normalization_bwd_t::pd_t::init_scratchpad() {
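    // The reduction buffer holds one partial result per chunk for both
    // diff_scale and diff_shift (hence the factor of 2), each norm_axis
    // wide, accumulated in f32.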
    const size_t size = conf.n_chunks * conf.norm_axis * 2;
    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(memory_tracking::names::key_lnorm_reduction, size,
            types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
}

status_t ref_layer_normalization_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {
    if (pd()->has_zero_dim_memory()) return status::success;

    status_t status = status::success;

    const auto &conf = pd()->conf;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &mean = CTX_IN_STORAGE(DNNL_ARG_MEAN);
    auto &variance = CTX_IN_STORAGE(DNNL_ARG_VARIANCE);
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);

    auto &diff_src = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
    CHECK(status);
    auto &diff_scale = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SCALE);
    auto &diff_shift = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SHIFT);

    if (conf.use_scale || conf.use_shift) {
        std::unique_ptr<memory_storage_t> temp_reduce;
        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, src);
        arg_list.set(1, mean);
        arg_list.set(2, variance);
        arg_list.set(3, diff_dst);
        if (conf.vectorize_bwd_scaleshift) {
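            // Write per-chunk partial results to the shared scratchpad; the
            // finalize kernel below reduces them into the real diff_scale
            // and diff_shift buffers.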
            temp_reduce = ctx.get_scratchpad_grantor().get_memory_storage(
                    memory_tracking::names::key_lnorm_reduction);
            arg_list.set(4, *temp_reduce);
            arg_list.set(5, *temp_reduce);
        } else {
            arg_list.set(4, diff_scale);
            arg_list.set(5, diff_shift);
        }
        arg_list.set(6, conf.eps);

        auto nd_range = conf.dispatch_scaleshift.nd_range();
        status = parallel_for(ctx, nd_range, kernel_scaleshift_, arg_list);
        if (status != status::success) return status;

        if (conf.vectorize_bwd_scaleshift) {
            compute::kernel_arg_list_t arg_list_final;
            arg_list_final.set(0, *temp_reduce);
            arg_list_final.set(1, diff_scale);
            arg_list_final.set(2, diff_shift);

            auto nd_range_finalize
                    = conf.dispatch_scaleshift_finalize.nd_range();
            status = parallel_for(ctx, nd_range_finalize,
                    kernel_scaleshift_finalize_, arg_list_final);
            if (status != status::success) return status;
        }
    }

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, mean);
    arg_list.set(2, variance);
    arg_list.set(3, diff_dst);
    arg_list.set(4, scale);
    arg_list.set(5, diff_src);
    arg_list.set(6, conf.eps);

    auto nd_range_kernel = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range_kernel, kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl