1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_PRIMITIVE_CONF_HPP |
18 | #define GPU_PRIMITIVE_CONF_HPP |
19 | |
20 | #include <stdint.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/memory_desc_wrapper.hpp" |
24 | #include "common/memory_storage.hpp" |
25 | #include "common/primitive_attr.hpp" |
26 | #include "common/primitive_exec_types.hpp" |
27 | #include "common/utils.hpp" |
28 | #include "gpu/compute/compute.hpp" |
29 | #include "gpu/gpu_eltwise_pd.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | |
35 | #define MAX_NDIMS 6 |
36 | #define MAX_POST_OPS_SUPPORTED 32 |
37 | |
// Returns true when the descriptor has MORE dimensions than GPU kernels
// support (MAX_NDIMS). NOTE: despite the "_ok" name, true means "NOT ok";
// callers use this as a bail-out predicate (e.g. return unimplemented).
inline bool memory_desc_ndims_ok(const memory_desc_t *md) {
    return md->ndims > MAX_NDIMS;
}
41 | |
// Variadic overload: true if ANY of the given descriptors exceeds
// MAX_NDIMS (same inverted "_ok" semantics as the single-arg version).
template <typename T, typename... Rest>
bool memory_desc_ndims_ok(const T *first, const Rest *... rest) {
    return memory_desc_ndims_ok(first) || memory_desc_ndims_ok(rest...);
}
46 | |
47 | inline dim_t get_attr_oscales_count(int mask, const memory_desc_wrapper &md) { |
48 | dim_t count = 1; |
49 | if (mask == 0) return count; |
50 | |
51 | for (int d = 0; d < md.ndims(); d++) { |
52 | const int dim_mask = 1 << d; |
53 | if (dim_mask & mask) count *= md.dims()[d]; |
54 | } |
55 | |
56 | return count; |
57 | } |
58 | |
// Plain-old-data summary of a memory descriptor (dims, padded dims and
// per-dimension blocking/strides) in fixed-size arrays suitable for
// passing to OpenCL kernel generators.
struct memory_desc_info_t {
    // Max levels of blocking
    static const int max_nlevels = 3;

    int ndims;
    data_type_t data_type;

    // Offset of the first element, taken from the descriptor.
    int offset0;
    int dims[MAX_NDIMS];
    int padded_dims[MAX_NDIMS];

    // Number of blocking levels filled below (create() fixes this to 2).
    int nlevels;
    // blocks[d][l] / strides[d][l]: block size and stride of dimension d
    // at blocking level l; level 0 holds the outer (plain) strides.
    int blocks[MAX_NDIMS][max_nlevels + 1];
    int strides[MAX_NDIMS][max_nlevels + 1];

    // Builds the summary from a memory descriptor wrapper.
    static memory_desc_info_t create(const memory_desc_wrapper &mdw) {
        using namespace format_tag;

        auto md_info = memory_desc_info_t();

        md_info.nlevels = 2;

        md_info.ndims = mdw.ndims();
        md_info.data_type = mdw.data_type();
        md_info.offset0 = mdw.offset0();

        auto &blk = mdw.blocking_desc();
        // Combined extent of all inner blocks; divided down as each inner
        // block is visited below to derive the per-level stride.
        dim_t blk_stride
                = utils::array_product(blk.inner_blks, blk.inner_nblks);

        // Defaults: block size 1, stride 0 for unused levels.
        for (int d = 0; d < mdw.ndims(); ++d) {
            utils::array_set(md_info.blocks[d], 1, md_info.nlevels + 1);
            utils::array_set(md_info.strides[d], 0, md_info.nlevels + 1);
        }

        for (int d = 0; d < mdw.ndims(); ++d) {
            md_info.dims[d] = mdw.dims()[d];
            md_info.padded_dims[d] = mdw.padded_dims()[d];
            md_info.strides[d][0] = blk.strides[d];
        }

        // Walk inner blocks outermost-first; each occurrence of a
        // dimension advances that dimension's blocking level.
        int levels[MAX_NDIMS] = {0};
        for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
            int d = blk.inner_idxs[iblk];
            ++levels[d];

            md_info.blocks[d][levels[d]] = blk.inner_blks[iblk];
            blk_stride /= blk.inner_blks[iblk];
            md_info.strides[d][levels[d]] = blk_stride;
        }
        return md_info;
    }
};
112 | |
113 | struct attr_info_t { |
114 | static attr_info_t create(const primitive_attr_t *attr) { |
115 | const auto &po = attr->post_ops_; |
116 | |
117 | attr_info_t attr_info; |
118 | |
119 | attr_info.binary_idx = po.find(primitive_kind::binary); |
120 | attr_info.with_binary = (attr_info.binary_idx != -1); |
121 | |
122 | // Eltwise |
123 | attr_info.eltwise_idx = po.find(primitive_kind::eltwise); |
124 | attr_info.with_eltwise = (attr_info.eltwise_idx != -1); |
125 | |
126 | if (attr_info.with_eltwise) { |
127 | auto &eltwise = po.entry_[attr_info.eltwise_idx].eltwise; |
128 | attr_info.eltwise_alg = eltwise.alg; |
129 | attr_info.eltwise_scale = eltwise.scale; |
130 | attr_info.eltwise_alpha = eltwise.alpha; |
131 | attr_info.eltwise_beta = eltwise.beta; |
132 | } else { |
133 | attr_info.eltwise_alg = alg_kind::undef; |
134 | attr_info.eltwise_scale = 1.0f; |
135 | attr_info.eltwise_alpha = 1.0f; |
136 | attr_info.eltwise_beta = 0.0f; |
137 | } |
138 | |
139 | // Sum |
140 | attr_info.sum_idx = po.find(primitive_kind::sum); |
141 | attr_info.sum_scale = (attr_info.sum_idx != -1 |
142 | ? po.entry_[attr_info.sum_idx].sum.scale |
143 | : 0.0f); |
144 | attr_info.sum_data_type = (attr_info.sum_idx != -1) |
145 | ? po.entry_[attr_info.sum_idx].sum.dt |
146 | : dnnl_data_type_undef; |
147 | attr_info.with_sum |
148 | = (attr_info.sum_idx != -1) && (attr_info.sum_scale != 0.0f); |
149 | |
150 | // Output scales |
151 | attr_info.with_oscales |
152 | = !attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); |
153 | |
154 | const auto &scales_mask = attr->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
155 | attr_info.with_common_oscales |
156 | = attr_info.with_oscales && (scales_mask == 0); |
157 | attr_info.with_per_oc_oscales |
158 | = attr_info.with_oscales && (scales_mask == (1 << 1)); |
159 | |
160 | attr_info.with_runtime_oscales = !attr->output_scales_.defined(); |
161 | |
162 | const auto &src0_scales = attr->scales_.get(DNNL_ARG_SRC_0); |
163 | attr_info.with_src0_scale = !src0_scales.has_default_values(); |
164 | assert(src0_scales.mask_ == 0); |
165 | |
166 | const auto &src1_scales = attr->scales_.get(DNNL_ARG_SRC_1); |
167 | attr_info.with_src1_scale = !src1_scales.has_default_values(); |
168 | assert(src1_scales.mask_ == 0); |
169 | |
170 | const auto &src_scales = attr->scales_.get(DNNL_ARG_SRC); |
171 | attr_info.with_src_scales = !src_scales.has_default_values(); |
172 | assert(src_scales.mask_ == 0); |
173 | |
174 | const auto &wei_scales = attr->scales_.get(DNNL_ARG_WEIGHTS); |
175 | attr_info.with_wei_scales = !wei_scales.has_default_values(); |
176 | attr_info.wei_scales_mask = wei_scales.mask_; |
177 | |
178 | const auto &dst_scales = attr->scales_.get(DNNL_ARG_DST); |
179 | attr_info.with_dst_scales = !dst_scales.has_default_values(); |
180 | assert(dst_scales.mask_ == 0); |
181 | |
182 | // zero points |
183 | const auto &zp = attr->zero_points_; |
184 | attr_info.with_src_zpoints = !zp.has_default_values(DNNL_ARG_SRC); |
185 | attr_info.with_wei_zpoints = !zp.has_default_values(DNNL_ARG_WEIGHTS); |
186 | attr_info.with_dst_zpoints = !zp.has_default_values(DNNL_ARG_DST); |
187 | |
188 | attr_info.with_per_ic_src_zpoints = attr_info.with_src_zpoints |
189 | && !zp.defined(DNNL_ARG_SRC) && !zp.common(DNNL_ARG_SRC); |
190 | |
191 | attr_info.with_per_oc_dst_zpoints = attr_info.with_dst_zpoints |
192 | && !zp.defined(DNNL_ARG_DST) && !zp.common(DNNL_ARG_DST); |
193 | |
194 | attr_info.initialized = true; |
195 | return attr_info; |
196 | } |
197 | |
198 | bool initialized = false; |
199 | |
200 | bool with_binary; |
201 | bool with_eltwise; |
202 | int eltwise_idx; |
203 | int binary_idx; |
204 | alg_kind_t eltwise_alg; |
205 | float eltwise_scale; |
206 | float eltwise_alpha; |
207 | float eltwise_beta; |
208 | |
209 | bool with_sum; |
210 | int sum_idx; |
211 | float sum_scale; |
212 | data_type_t sum_data_type; |
213 | |
214 | bool with_oscales; |
215 | bool with_common_oscales; |
216 | bool with_per_oc_oscales; |
217 | bool with_runtime_oscales; |
218 | |
219 | bool with_src0_scale; |
220 | bool with_src1_scale; |
221 | bool with_src_scales; |
222 | bool with_wei_scales; |
223 | bool with_dst_scales; |
224 | bool wei_scales_mask; |
225 | |
226 | bool with_src_zpoints; |
227 | bool with_wei_zpoints; |
228 | bool with_dst_zpoints; |
229 | bool with_per_ic_src_zpoints; |
230 | bool with_per_oc_dst_zpoints; |
231 | }; |
232 | |
// Precomputed offset tables (4 rows x MAX_NDIMS) for the src, weights
// and dst tensors, consumed by the generic OpenCL offset macros.
struct offsets_t {
    int src_off[4][MAX_NDIMS];
    int wei_off[4][MAX_NDIMS];
    int dst_off[4][MAX_NDIMS];
};
238 | |
// Per-tensor offset tables for every RNN argument (forward and backward),
// same 4 x MAX_NDIMS layout as offsets_t.
struct rnn_offsets_t {
    int src_layer_off[4][MAX_NDIMS];
    int src_iter_off[4][MAX_NDIMS];
    int src_iter_c_off[4][MAX_NDIMS];
    int weights_layer_off[4][MAX_NDIMS];
    int weights_iter_off[4][MAX_NDIMS];
    int bias_off[4][MAX_NDIMS];
    int dst_layer_off[4][MAX_NDIMS];
    int dst_iter_off[4][MAX_NDIMS];
    int dst_iter_c_off[4][MAX_NDIMS];
    int diff_src_layer_off[4][MAX_NDIMS];
    int diff_src_iter_off[4][MAX_NDIMS];
    int diff_src_iter_c_off[4][MAX_NDIMS];
    int diff_weights_layer_off[4][MAX_NDIMS];
    int diff_weights_iter_off[4][MAX_NDIMS];
    int diff_bias_off[4][MAX_NDIMS];
    int diff_dst_layer_off[4][MAX_NDIMS];
    int diff_dst_iter_off[4][MAX_NDIMS];
    int diff_dst_iter_c_off[4][MAX_NDIMS];
    int ws_off[4][MAX_NDIMS];
};
260 | |
261 | // Convolution |
262 | enum conv_version_t { |
263 | ver_unused, |
264 | ver_1stconv, |
265 | ver_16mb16c, |
266 | ver_32mb16c, |
267 | ver_32mb32c, |
268 | ver_32c, |
269 | ver_8ow16c, |
270 | ver_nhwc, |
271 | ver_nchw, |
272 | ver_mb_block, |
273 | ver_ow_block, |
274 | |
275 | // Xe_HP-specific versions. |
276 | ver_v1, |
277 | ver_v2 |
278 | }; |
279 | |
// Full configuration of a GPU convolution kernel: problem shape,
// blocking/tiling choices, dispatch parameters and data types.
struct conv_conf_t {
    prop_kind_t prop_kind;

    // Problem shape: batch, groups, channels, spatial sizes, padding,
    // kernel sizes, strides and dilations.
    int ndims;
    int mb;
    int ngroups, ic, oc;
    int ngroups_without_padding, oc_without_padding, ic_without_padding;
    int id, ih, iw, od, oh, ow;
    int f_pad, l_pad, t_pad;
    int back_pad, r_pad, b_pad;
    int kd, kh, kw, kwb;
    int stride_d, stride_h, stride_w;
    int dilate_d, dilate_h, dilate_w;

    // Blocking / tiling parameters chosen by the implementation.
    int sp_block, sp;
    int od_block, oh_block, ow_block;
    int id_block, ih_block, iw_block;
    int oc_block, ic_block, nchunk;
    int omb;
    int odb, ohb, owb;
    int icb;
    int ocb;
    int osp_chunk, mb_chunk, mb_block;
    int iw_tail;
    size_t wei_slm_size, src_slm_size, dst_slm_size;
    int sub_group_size;

    size_t gws_d[3], lws_d[3];
    // Original global work sizes, before applying rounding in case when
    // non-uniform work-groups are not supported.
    size_t gws_orig_d[3];
    compute::dispatch_t dispatch;

    bool with_bias, with_groups;

    attr_info_t attr_info;

    bool is_depthwise;
    bool is_nhwc;
    bool reorder_wei = false;
    bool reorder_bias = false;
    int ver; // conv_version_t value
    format_tag_t src_tag, dst_tag, wei_tag;
    bool is_nchw;
    bool is_src_nchw, is_src_nhwc;
    bool is_dst_nhwc;

    // Winograd-specific parameters.
    int tile_size;
    int wino_m;
    int wino_r;
    int wino_ih, wino_oh;
    int wino_iw, wino_ow;
    int wino_ic;
    int wino_oc;
    int wino_ic_block;
    int wino_oc_block;
    int vect_size;
    size_t U_gws_d[3], U_lws_d[3];
    size_t V_gws_d[3], V_lws_d[3];
    size_t M_gws_d[3], M_lws_d[3];
    bool is_fused;

    data_type_t src_data_type;
    data_type_t weights_data_type;
    data_type_t bias_data_type;
    data_type_t dst_data_type;
    data_type_t acc_data_type;

    memory_desc_info_t src_md_info;
    memory_desc_info_t wei_md_info;
    memory_desc_info_t dst_md_info;
};
352 | |
353 | // Pooling |
// Pooling
// Configuration of a GPU pooling kernel; filled by set_default_pool_conf().
struct pool_conf_t {
    int ndims;
    int mb, c;
    int mb_padded;
    int c_padded;
    int id, ih, iw, od, oh, ow;
    int stride_d, stride_h, stride_w;
    int kd, kh, kw;
    // Dilations.
    int dd, dh, dw;
    // Front / top / left padding.
    int f_pad, t_pad, l_pad;
    data_type_t src_dt;
    data_type_t dst_dt;
    alg_kind_t alg;
    bool is_plain;
    bool is_training, is_backward;
    bool use_mb_c_block, use_only_c_block;
    bool unroll_mb = false;
    int chunks_per_c_block, chunks_per_mb_block;
    int vect_dt_n;
    int nvect;
    compute::dispatch_t dispatch;
    int sub_group_size;
    int global_pool_spatial_chunk;
    int num_batches = 1;
    int mb_block_size = 16;

    attr_info_t attr_info;
    memory_desc_info_t src_md_info;
    memory_desc_info_t dst_md_info;
};
384 | |
385 | // Prelu |
// Prelu
// Configuration of a GPU PReLU kernel (forward or backward).
struct prelu_conf_t {
    bool is_forward;
    // Backward only: whether diff_weights require a reduction pass.
    bool reduce_diff_weights;
    compute::dispatch_t dispatch;

    attr_info_t attr_info;
    memory_desc_info_t src_md_info;
    memory_desc_info_t wei_md_info;
    memory_desc_info_t dst_md_info;
    memory_desc_info_t diff_src_md_info;
    memory_desc_info_t diff_wei_md_info;
};
398 | |
399 | // Inner Product |
// Inner Product
// Configuration of a GPU inner-product (fully-connected) kernel.
struct inner_product_conf_t {
    int ndims;
    int src_ndims, wei_ndims, dst_ndims;
    int mb, oc, ic, ic_total;
    int id, ih, iw, od, oh, ow;
    int kd, kh, kw;
    bool with_bias, has_spatial;
    // Exactly one of these propagation-kind flags is set.
    bool is_forward, is_backward_data, is_backward_weights;
    compute::dispatch_t dispatch;
    bool reorder_dst = false;

    data_type_t src_dt;
    data_type_t wei_dt;
    data_type_t bia_dt;
    data_type_t dst_dt;
    data_type_t acc_dt;

    attr_info_t attr_info;
};
419 | |
420 | // RNN |
// RNN
// Configuration of the GPU RNN implementation: cell/activation kinds,
// problem sizes, per-tensor ndims, leading dimensions and workspace /
// scratchpad byte offsets.
struct rnn_conf_t {
    int cell_kind;
    int activation_kind;
    int direction_kind;
    bool with_bias;
    bool with_src_iter;
    bool with_src_iter_c;
    bool with_dst_iter;
    bool with_dst_iter_c;
    bool is_lbr;
    bool is_vanilla_gru;
    bool is_fwd;
    bool copy_bias;
    bool is_int8;
    bool is_testmode;
    bool is_training;
    data_type_t src_dt;
    data_type_t wei_dt;
    data_type_t bia_dt;
    data_type_t dst_dt;
    data_type_t acc_dt;
    data_type_t aux_dt;
    data_type_t input_dt;
    data_type_t output_dt;
    data_type_t diff_dt;

    // Problem sizes: layers, directions, time steps, gates and channels.
    int n_layer;
    int n_dir;
    int n_iter;
    int n_iter_scratch_gates;
    int n_gates;
    int n_bias;
    int n_states;
    int n_weights_input;
    int n_weights_state;
    int batch;
    int slc;
    int sic;
    int dhc;
    int dlc;
    int wic;
    int arch_ld;
    int n_parts_weights_iter, n_parts_weights_layer;
    // ndims of each argument's memory descriptor.
    int src_layer_ndims;
    int src_iter_ndims;
    int src_iter_c_ndims;
    int weights_layer_ndims;
    int weights_iter_ndims;
    int dst_layer_ndims;
    int dst_iter_ndims;
    int dst_iter_c_ndims;
    int bias_ndims;
    int diff_src_layer_ndims;
    int diff_src_iter_ndims;
    int diff_src_iter_c_ndims;
    int diff_weights_layer_ndims;
    int diff_weights_iter_ndims;
    int diff_dst_layer_ndims;
    int diff_dst_iter_ndims;
    int diff_dst_iter_c_ndims;
    int diff_bias_ndims;
    int states_ws_ld, gates_ws_ld, scratch_diff_states_ld, scratch_gates_ld;

    // Mask for weights quantization parameters.
    int wei_qparam_mask;

    // Byte offsets into the workspace / scratchpad buffers.
    size_t ws_gates_offset;
    size_t ws_states_offset;
    size_t ws_grid_comp_offset;
    size_t ws_h_state_offset;
    size_t ws_c_state_offset;
    size_t ws_bias_offset;
    size_t scratchpad_size;
    size_t scratch_dhG1_offset;
    size_t scratch_gates_offset;
    size_t scratch_cell_offset;
    size_t scratch_diff_states_offset;
    size_t workspace_size;
};
499 | |
// Configuration of the RNN weights-reorder kernel (including optional
// quantization via mask/scales).
struct rnn_reorder_conf_t {
    bool do_reorder, with_group, has_padding;
    bool with_sum_ab, with_sum_a;
    bool use_ref_impl;
    int ndims;
    size_t nelems;
    compute::dispatch_t dispatch;
    int block[3];
    int sub_group_size;
    int mask;
    size_t scales_count;
};
512 | |
513 | // Batch Normalization |
// Batch Normalization
// Configuration of the GPU batch-normalization kernels, including the
// separate statistics-calculation and reduction dispatches.
struct bnorm_conf_t {
    data_type_t data_type;

    int ndims;
    int mb, ic, mb_block, ic_block;
    int reduce_dim_idx, reduce_dim;
    int id, ih, iw;
    int nn, sp, sp_tail, vect_size;
    int stat_sp_nblocks, stat_sp_tail, stat_sp_block;
    int reduce_stat_nblocks;
    bool with_relu, use_16mb_unroll, use_nhwc;
    int stat_ic;
    bool is_forward, is_backward;
    bool use_scale, use_shift, save_stats, is_training;
    bool calculate_stats, calculate_diff_stats;
    bool fuse_norm_relu, fuse_norm_add_relu;
    bool diff_scale, diff_shift;
    float relu_negative_slope, eps;
    int sub_group_size;
    bool vectorize_calc_stats;
    bool skip_reduce_stat;
    // Single-pass mean/variance computation variant.
    bool use_stats_one_pass;
    bool nhwc_optimized;
    int calc_stat_ic;
    bool use_fused_atomics_reduction;

    // Separate dispatches for the stat-calc, stat-reduce, main and
    // auxiliary-reduce kernels.
    compute::dispatch_t dispatch_calc_stat;
    compute::dispatch_t dispatch_reduce_stat;
    compute::dispatch_t dispatch;
    compute::dispatch_t dispatch_reduce_aux;
};
545 | |
546 | // Layer Normalization |
// Layer Normalization
// Configuration of the GPU layer-normalization kernels.
struct lnorm_conf_t {
    data_type_t data_type;

    bool is_fwd;
    int ndims;
    int norm_axis;

    memory_desc_info_t src_md_info;
    memory_desc_info_t dst_md_info;
    memory_desc_info_t stat_md_info;

    bool use_scale;
    bool use_shift;
    bool calculate_stats;
    bool save_stats;
    bool vectorize_calc_stats;
    bool vectorize_bwd;
    bool vectorize_bwd_scaleshift;
    float eps;
    int sub_group_size;
    int vect_dt_n;
    int shift_off;
    int n_chunk_size;
    int n_chunks;
    int vector_size_scaleshift;

    // Separate dispatches for the backward scale/shift kernels.
    compute::dispatch_t dispatch_scaleshift;
    compute::dispatch_t dispatch_scaleshift_finalize;
    compute::dispatch_t dispatch;
};
577 | |
578 | // Binary |
// Binary
// Configuration of the GPU binary (elementwise two-input) kernel; one
// is_* flag per supported algorithm.
struct binary_conf_t {
    int ndims, nvect;
    bool use_unroll_16b, src0_unroll_16b;
    bool is_plain_layout;
    bool plain_to_ABcd4a4b;
    bool isXa16b;
    data_type_t src0_data_type;
    data_type_t src1_data_type;
    data_type_t dst_data_type;
    // Algorithm selectors (exactly one expected to be set).
    bool is_mul;
    bool is_add;
    bool is_max;
    bool is_min;
    bool is_div;
    bool is_sub;
    bool is_ge;
    bool is_gt;
    bool is_le;
    bool is_lt;
    bool is_eq;
    bool is_ne;
    bool is_tensor_op;
    compute::dispatch_t dispatch;
    int mb_block;
    int dim0[MAX_NDIMS];
    // Broadcast masks per dimension for each input.
    int src0_bcast_dims[MAX_NDIMS];
    int src1_bcast_dims[MAX_NDIMS];
    bool is_dense;
    bool is_same_md;
    bool same_src_dt;
    bool with_binary_post_op;

    memory_desc_info_t src0_md_info;
    memory_desc_info_t src1_md_info;
    memory_desc_info_t dst_md_info;

    attr_info_t attr_info;
};
617 | |
618 | // Reduction |
// Reduction
// One phase of a multi-phase reduction: its kernel, launch range and the
// sizes it reduces between.
struct reduction_phase_t {
    data_type_t src_type, dst_type;
    compute::nd_range_t nd_range;
    compute::kernel_t kernel;
    int reduction_size, initial_size, final_size;
    int num_reduction_chunks;
    bool is_final, is_first;
};
627 | |
// Configuration shared by the reference, gen9 and combined reduction
// implementations (each uses its own subset of fields, see comments).
struct reduction_conf_t {
    // Used by reference implementation
    alg_kind_t alg;
    int ndims, power, div;
    float eps;
    dim_t src_dims[MAX_NDIMS], reduce_dims[MAX_NDIMS], dst_dims[MAX_NDIMS];
    bool is_reduction_dim[MAX_NDIMS];
    data_type_t src_type, dst_type;
    memory_desc_info_t src_md_info, dst_md_info;
    compute::dispatch_t dispatch;
    offsets_t off;
    attr_info_t attr_info;

    // Used by gen9 implementation
    int initial_hwd_dim, initial_hwd_chunk_size;
    int final_hwd_dim, final_hwd_chunk_size;
    int initial_c_chunks, final_c_dim, final_c_chunk_size;
    int initial_n_chunk_size, initial_n_chunks;
    int final_n_dim, final_n_chunk_size;
    bool skip_final_phase;
    int c_block_size, n_block_size;
    int vector_size;
    int sub_group_size;
    compute::dispatch_t finalize_dispatch;

    // Used by combined implementation
    int outer_dim_size, inner_dim_size, gws_inner_dim_size;
    std::vector<reduction_phase_t> phases;
    int inner_dim_per_sg;
    size_t sp_size[2];
};
659 | |
660 | // Reorder |
661 | enum reorder_kernel_t { |
662 | none, |
663 | dense_vector, |
664 | unroll_16b, |
665 | unroll_16b16c, |
666 | unroll_16a16b, |
667 | plain_to_ABcd84a42b, |
668 | vectorize_last_dim, |
669 | plain_to_ABxx8ayb, |
670 | plain_xFxE_to_abcdef, |
671 | transpose8x8, |
672 | transpose16x16, |
673 | local8x8, |
674 | local16x16, |
675 | reorder_nchw, |
676 | unaligned_sizes, |
677 | reorder_alt, |
678 | vectorize_groups, |
679 | pad_innermost, |
680 | xb_to_xab_xba |
681 | }; |
682 | |
683 | // Resampling |
// Resampling
// Configuration of the GPU resampling kernel; FD/FH/FW are the
// depth/height/width scaling factors.
struct resampling_conf_t {
    dim_t ndims;
    offsets_t off;
    dim_t MB, C;
    dim_t ID, IH, IW;
    dim_t OD, OH, OW;
    float FD, FH, FW;
    dim_t vect_size;
    dims_t padded_strides;
    size_t lws[3], gws[3];
    int sub_group_size;
    dim_t padded_c;
    attr_info_t attr_info;
    compute::dispatch_t dispatch;
};
699 | |
// One blocked dimension: which dim, block size and the step between
// consecutive blocks. Used by the reorder vectorization descriptors.
struct block_desc_t {
    int dim_idx;
    int blk_size;
    int step_size;
};
705 | |
#define LOOP_NEST_LEVEL 4
// Auxiliary data for the vectorize_last_dim reorder kernel.
struct vectorize_last_dim_t {
    int vector_dim;
    int rescale_coeff;
    // composition of data within 16-item packet
    block_desc_t src_vct[LOOP_NEST_LEVEL];
    block_desc_t dst_vct[LOOP_NEST_LEVEL];
    // dimensions to loop over when accessing packets defined above
    block_desc_t src_blk[LOOP_NEST_LEVEL];
    block_desc_t dst_blk[LOOP_NEST_LEVEL];
    int src_blk_limits[MAX_NDIMS];
    int dst_blk_limits[MAX_NDIMS];
    int src_vect_limit;
    int dst_vect_limit;
};
721 | |
// Auxiliary data for the vectorize_groups reorder kernel.
struct vectorize_group_t {
    int vector_dim;
    int src_loop_dim;
    int dst_loop_dim;
    int group_size;
    int innermost_size;
};
729 | |
// Auxiliary data for the xb_to_xab_xba reorder kernel.
struct xb_to_xab_xba_t {
    int vd;
    int blk_size;
    int src_blk_dim;
    int src_blk_coeff;
    int dst_blk_dim;
    int dst_blk_coeff;
};
738 | |
// Per-variant auxiliary data; which member is active is determined by
// reorder_conf_t::implementation.
union reorder_implementation {
    vectorize_group_t vg;
    xb_to_xab_xba_t ab;
    vectorize_last_dim_t vld;
};
744 | |
// Captures the quantization scales of one argument at pd-creation time
// and provides access to the runtime scales buffer during execution.
class scales_query_t {
public:
    bool has_default_values() const { return scales_.has_default_values(); }
    int get_mask() const { return scales_.mask_; }
    size_t get_count() const { return count_; }
    // Fetches the runtime scales memory from the execution context
    // (ctx is consumed by the CTX_IN_STORAGE macro).
    memory_storage_t &get_scales(const exec_ctx_t &ctx) const {
        return CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | arg_);
    }

    scales_query_t() = default;
    scales_query_t(const primitive_attr_t *attr, const memory_desc_wrapper &mdw,
            int arg)
        : arg_(arg) {
        scales_ = attr->scales_.get(arg);
        // Number of scale values implied by the mask and the md's dims.
        count_ = get_attr_oscales_count(scales_.mask_, mdw);
    }

private:
    runtime_scales_t scales_;
    dim_t count_ = 0;
    int arg_ = 0;
};
767 | |
768 | class zero_points_query_t { |
769 | public: |
770 | bool has_default_values() const { return zps_.has_default_values(arg_); } |
771 | int get_mask() const { |
772 | int mask; |
773 | zps_.get(arg_, &mask); |
774 | return mask; |
775 | } |
776 | size_t get_count() const { return count_; } |
777 | memory_storage_t &get_zero_points(const exec_ctx_t &ctx) const { |
778 | return CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | arg_); |
779 | } |
780 | |
781 | zero_points_query_t() = default; |
782 | zero_points_query_t(const primitive_attr_t *attr, |
783 | const memory_desc_wrapper &mdw, int arg) |
784 | : arg_(arg) { |
785 | zps_ = attr->zero_points_; |
786 | int mask; |
787 | zps_.get(arg, &mask); |
788 | count_ = get_attr_oscales_count(mask, mdw); |
789 | } |
790 | |
791 | private: |
792 | zero_points_t zps_; |
793 | dim_t count_; |
794 | int arg_; |
795 | }; |
796 | |
// Bundles the scale and zero-point queries of one argument and emits the
// corresponding kernel-compile-time macros (WITH_<NAME>_SCALE, etc.).
struct quantization_t {
public:
    bool with_scale() const { return !scale_.has_default_values(); }
    int scale_mask() const { return scale_.get_mask(); }
    size_t num_scales() const { return scale_.get_count(); }
    memory_storage_t &scales(const exec_ctx_t &ctx) const {
        return scale_.get_scales(ctx);
    }

    bool with_zp() const { return !zp_.has_default_values(); }
    int zp_mask() const { return zp_.get_mask(); }
    size_t num_zps() const { return zp_.get_count(); }
    memory_storage_t &zero_points(const exec_ctx_t &ctx) const {
        return zp_.get_zero_points(ctx);
    }

    // Defines the quantization macros for this argument in the kernel
    // context; `name` is the argument prefix, e.g. "SRC" or "DST".
    void define_macros(
            compute::kernel_ctx_t &kernel_ctx, const std::string &name) const {
        if (with_scale()) {
            kernel_ctx.define_int("WITH_" + name + "_SCALE" , 1);
            kernel_ctx.define_int(name + "_SCALE_MASK" , scale_mask());
            kernel_ctx.define_int(name + "_NUM_SCALES" , num_scales());
        }

        if (with_zp()) {
            kernel_ctx.define_int("WITH_" + name + "_ZPOINT" , 1);
            kernel_ctx.define_int(name + "_ZPOINT_MASK" , zp_mask());
            kernel_ctx.define_int(name + "_NUM_ZPOINTS" , num_zps());
        }
    }

    quantization_t(const primitive_attr_t *attr, const memory_desc_wrapper &mdw,
            int arg)
        : scale_(attr, mdw, arg), zp_(attr, mdw, arg) {}
    quantization_t() = default;

private:
    scales_query_t scale_;
    zero_points_query_t zp_;
};
837 | |
// Quantization view of a sum post-op: a single common scale and zero
// point taken from the first sum entry (zero values mean "absent").
struct sum_quantization_t {
public:
    bool with_scale() const { return scale_ != 0; }
    int scale_mask() const { return 0; } // sum scale is always common
    size_t num_scales() const { return (size_t)(with_scale()); }
    float scales() const { return scale_; }

    bool with_zp() const { return zp_ != 0; }
    int zp_mask() const { return 0; } // sum zero point is always common
    size_t num_zps() const { return (size_t)(with_zp()); }
    int zero_points() const { return zp_; }

    // Defines the sum-quantization macros; `name` is the prefix ("SUM").
    void define_macros(
            compute::kernel_ctx_t &kernel_ctx, const std::string &name) const {
        if (with_scale()) kernel_ctx.define_int("WITH_" + name + "_SCALE" , 1);
        if (with_zp()) kernel_ctx.define_int("WITH_" + name + "_ZPOINT" , 1);
    }

    sum_quantization_t(const primitive_attr_t *attr) {
        const auto &post_ops = attr->post_ops_;
        const int sum_idx = post_ops.find(primitive_kind::sum);
        if (sum_idx != -1) {
            const auto &sum = post_ops.entry_[sum_idx].sum;
            scale_ = sum.scale;
            zp_ = sum.zero_point;
        }
    }
    sum_quantization_t() = default;

private:
    float scale_ = 0;
    int zp_ = 0;
};
871 | |
// Configuration of the GPU reorder kernel: selected variant, its
// auxiliary data and src/dst quantization.
struct reorder_conf_t {
    bool has_padding;

    quantization_t src_quant, dst_quant;
    sum_quantization_t sum_quant;

    // Selected kernel variant; determines which aux_data member is valid.
    reorder_kernel_t implementation;
    int ndims;
    size_t nelems;

    compute::dispatch_t dispatch;

    int sub_group_size;
    memory_desc_info_t src_md_info;
    memory_desc_info_t dst_md_info;

    reorder_implementation aux_data;
};
890 | |
891 | // Concat |
// Concat
// Configuration of the GPU concat kernel; per-source arrays are sized
// for up to 64 inputs.
struct concat_conf_t {
    dim_t dst_extern_dim_size;
    dim_t src_extern_dim_sizes[64];
    // Destination offset of each source along the concat axis.
    dim_t offset[64];
    dim_t inner_axis;
    dim_t dst_offset0;
    int block;
    int n; // number of sources
    int simd;
    int data_type_size;
    size_t gws_d[3], lws_d[3];

    data_type_t src_type, dst_type;
    compute::dispatch_t dispatch;
    int ndims;
    memory_desc_info_t src_md_infos[64];
    memory_desc_info_t dst_md_info;
    int concat_axis;
    int sub_group_size;
    int iter_dim_idx, iter_dim_chunk;
};
913 | |
914 | // Elementwise |
// Elementwise
// Configuration of the GPU eltwise kernel (forward and backward).
struct eltwise_conf_t {
    int ndims;
    int vector_size;
    bool with_zero_padding;
    data_type_t data_type;
    alg_kind_t alg;
    bool is_forward;
    int work_group_size;
    int sub_group_size;
    compute::dispatch_t dispatch;
    memory_desc_info_t data_md_info;
    // Backward only: descriptor of the diff tensor.
    memory_desc_info_t data_diff_md_info;

    attr_info_t attr_info;
};
930 | |
931 | // Shuffle |
// Shuffle
// Configuration of the GPU shuffle kernel; the shuffle is expressed as a
// transpose of a (transpose_row x transpose_col) view along `axis`.
struct shuffle_conf_t {
    data_type_t data_type;
    int axis;
    int transpose_row;
    int transpose_col;
    compute::dispatch_t dispatch;
    memory_desc_info_t src_md_info;
    memory_desc_info_t dst_md_info;
};
941 | |
// Fills `conf` with the common pooling parameters derived from the op
// descriptor, source/destination memory descriptors and attributes.
// Spatial fields are normalized for 3D/4D/5D inputs: missing depth gets
// size 1 (stride 1, pad 0), missing height likewise for 3D (NCW) cases.
inline void set_default_pool_conf(pool_conf_t &conf, const pooling_desc_t &desc,
        const memory_desc_t &src_md, const memory_desc_t &dst_md,
        const primitive_attr_t &attr) {
    const memory_desc_wrapper src_mdw(src_md);
    const memory_desc_wrapper dst_mdw(dst_md);

    const auto &src_dims = src_mdw.dims();
    const auto &dst_dims = dst_mdw.dims();

    int ndims = src_mdw.ndims();
    conf.ndims = ndims;

    conf.mb = src_dims[0];

    conf.c = src_dims[1];
    conf.mb_padded = src_mdw.padded_dims()[0];
    conf.c_padded = src_mdw.padded_dims()[1];
    // Spatial sizes: depth exists only for 5D, height only for 4D/5D.
    conf.id = (ndims == 5) ? src_dims[2] : 1;
    conf.ih = (ndims == 3) ? 1 : src_dims[ndims - 2];
    conf.iw = src_dims[ndims - 1];
    conf.od = (ndims == 5) ? dst_dims[2] : 1;
    conf.oh = (ndims == 3) ? 1 : dst_dims[ndims - 2];
    conf.ow = dst_dims[ndims - 1];

    // desc arrays are indexed over spatial dims only, hence ndims - 3/4.
    conf.stride_d = (ndims == 5) ? desc.strides[0] : 1;
    conf.stride_h = (ndims == 3) ? 1 : desc.strides[ndims - 4];
    conf.stride_w = desc.strides[ndims - 3];
    conf.kd = (ndims == 5) ? desc.kernel[0] : 1;
    conf.kh = (ndims == 3) ? 1 : desc.kernel[ndims - 4];
    conf.kw = desc.kernel[ndims - 3];

    conf.dd = (ndims == 5) ? desc.dilation[0] : 0;
    conf.dh = (ndims == 3) ? 0 : desc.dilation[ndims - 4];
    conf.dw = desc.dilation[ndims - 3];

    // Only the leading (front/top/left) paddings are stored.
    conf.f_pad = (ndims == 5) ? desc.padding[0][0] : 0;
    conf.t_pad = (ndims == 3) ? 0 : desc.padding[0][ndims - 4];
    conf.l_pad = desc.padding[0][ndims - 3];

    conf.alg = desc.alg_kind;

    conf.src_dt = src_mdw.data_type();
    conf.dst_dt = dst_mdw.data_type();

    conf.src_md_info = memory_desc_info_t::create(src_mdw);
    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);

    conf.is_training = desc.prop_kind == prop_kind::forward_training;
    conf.is_backward = desc.prop_kind == prop_kind::backward_data;

    conf.attr_info = attr_info_t::create(&attr);
}
994 | |
// Populates a conv_conf_t with the common convolution parameters (shapes,
// padding, strides, dilations, data types) derived from the operation
// descriptor and the src/weights/dst/bias memory descriptors.
// Spatial convention: ndims == 3/4/5 corresponds to 1D/2D/3D convolution;
// missing spatial dims get neutral values (extent 1, stride 1, dilation 0,
// padding 0).
inline void set_default_conf(conv_conf_t &conf, const convolution_desc_t &cd,
        const memory_desc_t &src_md, const memory_desc_t &weights_md,
        const memory_desc_t &dst_md, const memory_desc_t &bias_md,
        const primitive_attr_t &attr) {

    const memory_desc_wrapper src_mdw(&src_md);
    const memory_desc_wrapper weights_mdw(&weights_md);
    const memory_desc_wrapper dst_mdw(&dst_md);
    const memory_desc_wrapper bias_mdw(&bias_md);

    // Grouped weights carry an extra leading "groups" dimension.
    const bool with_groups = weights_mdw.ndims() == src_mdw.ndims() + 1;
    int ndims = src_mdw.ndims();

    conf = utils::zero<decltype(conf)>();
    conf.with_groups = with_groups;
    conf.ndims = ndims;
    conf.prop_kind = cd.prop_kind;
    conf.ngroups = with_groups ? weights_mdw.dims()[0] : 1;
    conf.mb = src_mdw.dims()[0];
    // Channel counts are stored per group.
    conf.oc_without_padding = dst_mdw.dims()[1] / conf.ngroups;
    conf.ic_without_padding = src_mdw.dims()[1] / conf.ngroups;
    conf.id = (ndims == 5) ? src_mdw.dims()[2] : 1;
    conf.ih = (ndims == 3) ? 1 : src_mdw.dims()[ndims - 2];
    conf.iw = src_mdw.dims()[ndims - 1];
    conf.od = (ndims == 5) ? dst_mdw.dims()[2] : 1;
    conf.oh = (ndims == 3) ? 1 : dst_mdw.dims()[ndims - 2];
    conf.ow = dst_mdw.dims()[ndims - 1];
    // Kernel extents: `with_groups` (0/1) shifts past the groups dimension.
    conf.kd = (ndims == 5) ? weights_mdw.dims()[with_groups + 2] : 1;
    conf.kh = (ndims == 3) ? 1 : weights_mdw.dims()[with_groups + ndims - 2];
    conf.kw = weights_mdw.dims()[with_groups + ndims - 1];

    conf.is_depthwise = conf.with_groups && conf.oc_without_padding == 1
            && conf.ic_without_padding == 1;
    conf.oc = dst_mdw.dims()[1] / conf.ngroups;
    conf.ic = src_mdw.dims()[1] / conf.ngroups;

    // padding[0] is the front/top/left side, padding[1] the back/bottom/right.
    conf.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    conf.back_pad = (ndims == 5) ? cd.padding[1][0] : 0;
    conf.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    conf.b_pad = (ndims == 3) ? 0 : cd.padding[1][ndims - 4];
    conf.l_pad = cd.padding[0][ndims - 3];
    conf.r_pad = cd.padding[1][ndims - 3];
    conf.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    conf.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    conf.stride_w = cd.strides[ndims - 3];
    conf.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    conf.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    conf.dilate_w = cd.dilates[ndims - 3];

    conf.with_bias = bias_mdw.format_kind() != format_kind::undef;

    conf.src_data_type = src_mdw.data_type();
    conf.weights_data_type = weights_mdw.data_type();
    conf.dst_data_type = dst_mdw.data_type();

    conf.acc_data_type = cd.accum_data_type;
    conf.bias_data_type
            = conf.with_bias ? bias_mdw.data_type() : data_type::f32;

    // Layout info is only meaningful once a concrete format has been chosen
    // (i.e. not format_kind::any).
    if (!src_mdw.format_any())
        conf.src_md_info = memory_desc_info_t::create(src_mdw);
    if (!weights_mdw.format_any())
        conf.wei_md_info = memory_desc_info_t::create(weights_mdw);
    if (!dst_mdw.format_any())
        conf.dst_md_info = memory_desc_info_t::create(dst_mdw);

    conf.attr_info = attr_info_t::create(&attr);
}
1063 | |
1064 | inline void set_offsets(compute::kernel_ctx_t &kernel_ctx, |
1065 | const memory_desc_wrapper &md, const char *str) { |
1066 | dim_t block_dims[DNNL_MAX_NDIMS]; |
1067 | dim_t strides_compat[2][DNNL_MAX_NDIMS]; |
1068 | |
1069 | md.compute_blocks(block_dims); |
1070 | md.compute_strides_compat(strides_compat); |
1071 | |
1072 | for (int d = 0; d < MAX_NDIMS; ++d) { |
1073 | const int block = block_dims[d]; |
1074 | |
1075 | kernel_ctx.define_int( |
1076 | utils::format("%s_B%d" , str, d), (d < md.ndims()) ? block : 1); |
1077 | kernel_ctx.define_int(utils::format("%s_S%d" , str, d), |
1078 | (d < md.ndims()) ? strides_compat[0][d] : 0); |
1079 | kernel_ctx.define_int(utils::format("%s_SB%d" , str, d), |
1080 | (d < md.ndims()) ? strides_compat[1][d] : 0); |
1081 | } |
1082 | |
1083 | kernel_ctx.define_int(utils::format("%s_OFFSET_PAD" , str), md.md_->offset0); |
1084 | } |
1085 | |
1086 | inline void set_offsets(const memory_desc_wrapper &md, int offs[4][MAX_NDIMS]) { |
1087 | dim_t block_dims[DNNL_MAX_NDIMS]; |
1088 | dim_t strides_compat[2][DNNL_MAX_NDIMS]; |
1089 | |
1090 | md.compute_blocks(block_dims); |
1091 | md.compute_strides_compat(strides_compat); |
1092 | const dims_t &dims = md.dims(); |
1093 | |
1094 | for (int d = 0; d < md.ndims(); ++d) { |
1095 | const int block = block_dims[d]; |
1096 | |
1097 | offs[0][d] = block; |
1098 | offs[1][d] = strides_compat[0][d]; |
1099 | offs[2][d] = strides_compat[1][d]; |
1100 | offs[3][d] = dims[d]; |
1101 | } |
1102 | } |
1103 | |
1104 | inline void def_offsets(const int offs[4][MAX_NDIMS], |
1105 | compute::kernel_ctx_t &kernel_ctx, const char *str, const int ndims) { |
1106 | |
1107 | for (int d = 0; d < MAX_NDIMS; d++) { |
1108 | kernel_ctx.define_int( |
1109 | utils::format("%s_B%d" , str, d), (d < ndims) ? offs[0][d] : 1); |
1110 | kernel_ctx.define_int( |
1111 | utils::format("%s_S%d" , str, d), (d < ndims) ? offs[1][d] : 0); |
1112 | kernel_ctx.define_int( |
1113 | utils::format("%s_SB%d" , str, d), (d < ndims) ? offs[2][d] : 0); |
1114 | kernel_ctx.define_int( |
1115 | utils::format("%s_D%d" , str, d), (d < ndims) ? offs[3][d] : 0); |
1116 | } |
1117 | } |
1118 | |
1119 | inline void def_data_type( |
1120 | compute::kernel_ctx_t &kernel_ctx, data_type_t dt, const char *str) { |
1121 | switch (dt) { |
1122 | case data_type::bf16: |
1123 | kernel_ctx.add_option( |
1124 | utils::format("-D%s_DATA_T=ushort -D%s_DT_BF16" , str, str)); |
1125 | break; |
1126 | case data_type::f16: |
1127 | kernel_ctx.add_option( |
1128 | utils::format("-D%s_DATA_T=half -D%s_DT_F16" , str, str)); |
1129 | break; |
1130 | case data_type::f32: |
1131 | kernel_ctx.add_option( |
1132 | utils::format("-D%s_DATA_T=float -D%s_DT_F32" , str, str)); |
1133 | break; |
1134 | case data_type::f64: |
1135 | kernel_ctx.add_option( |
1136 | utils::format("-D%s_DATA_T=double -D%s_DT_F64" , str, str)); |
1137 | break; |
1138 | case data_type::s8: |
1139 | kernel_ctx.add_option( |
1140 | utils::format("-D%s_DATA_T=char -D%s_DT_S8" , str, str)); |
1141 | break; |
1142 | case data_type::u8: |
1143 | kernel_ctx.add_option( |
1144 | utils::format("-D%s_DATA_T=uchar -D%s_DT_U8" , str, str)); |
1145 | break; |
1146 | case data_type::s32: |
1147 | kernel_ctx.add_option( |
1148 | utils::format("-D%s_DATA_T=int -D%s_DT_S32" , str, str)); |
1149 | break; |
1150 | default: assert(!"unsupported data type" ); break; |
1151 | } |
1152 | } |
1153 | |
1154 | inline void def_memory_desc_info(compute::kernel_ctx_t &kernel_ctx, |
1155 | const memory_desc_info_t &md_info, const char *prefix) { |
1156 | def_data_type(kernel_ctx, md_info.data_type, prefix); |
1157 | |
1158 | kernel_ctx.define_int(utils::format("%s_OFFSET0" , prefix), md_info.offset0); |
1159 | kernel_ctx.define_int(utils::format("%s_NDIMS" , prefix), md_info.ndims); |
1160 | |
1161 | kernel_ctx.define_int(utils::format("%s_NLEVELS" , prefix), md_info.nlevels); |
1162 | |
1163 | for (int d = 0; d < MAX_NDIMS; ++d) { |
1164 | int dim = (d < md_info.ndims) ? md_info.dims[d] : 1; |
1165 | int padded_dim = (d < md_info.ndims) ? md_info.padded_dims[d] : 1; |
1166 | kernel_ctx.define_int(utils::format("%s_D%d" , prefix, d), dim); |
1167 | kernel_ctx.define_int(utils::format("%s_PD%d" , prefix, d), padded_dim); |
1168 | |
1169 | for (int l = 0; l < md_info.nlevels + 1; ++l) { |
1170 | int block = (d < md_info.ndims) ? md_info.blocks[d][l] : 1; |
1171 | int stride = (d < md_info.ndims) ? md_info.strides[d][l] : 0; |
1172 | kernel_ctx.define_int( |
1173 | utils::format("%s_B%d_%d" , prefix, d, l), block); |
1174 | kernel_ctx.define_int( |
1175 | utils::format("%s_S%d_%d" , prefix, d, l), stride); |
1176 | } |
1177 | } |
1178 | } |
1179 | |
1180 | inline void def_binary_alg_kinds(compute::kernel_ctx_t &kernel_ctx) { |
1181 | kernel_ctx.define_int("BINARY_ADD" , alg_kind::binary_add); |
1182 | kernel_ctx.define_int("BINARY_MUL" , alg_kind::binary_mul); |
1183 | kernel_ctx.define_int("BINARY_MIN" , alg_kind::binary_min); |
1184 | kernel_ctx.define_int("BINARY_MAX" , alg_kind::binary_max); |
1185 | kernel_ctx.define_int("BINARY_DIV" , alg_kind::binary_div); |
1186 | kernel_ctx.define_int("BINARY_SUB" , alg_kind::binary_sub); |
1187 | kernel_ctx.define_int("BINARY_GE" , alg_kind::binary_ge); |
1188 | kernel_ctx.define_int("BINARY_GT" , alg_kind::binary_gt); |
1189 | kernel_ctx.define_int("BINARY_LE" , alg_kind::binary_le); |
1190 | kernel_ctx.define_int("BINARY_LT" , alg_kind::binary_lt); |
1191 | kernel_ctx.define_int("BINARY_EQ" , alg_kind::binary_eq); |
1192 | kernel_ctx.define_int("BINARY_NE" , alg_kind::binary_ne); |
1193 | } |
1194 | |
1195 | inline void def_eltwise_alg_kinds(compute::kernel_ctx_t &kernel_ctx) { |
1196 | kernel_ctx.define_int("RELU" , alg_kind::eltwise_relu); |
1197 | kernel_ctx.define_int("LINEAR" , alg_kind::eltwise_linear); |
1198 | kernel_ctx.define_int("SOFT_RELU" , alg_kind::eltwise_soft_relu); |
1199 | kernel_ctx.define_int("MISH" , alg_kind::eltwise_mish); |
1200 | kernel_ctx.define_int("LOGISTIC" , alg_kind::eltwise_logistic); |
1201 | kernel_ctx.define_int("TANH" , alg_kind::eltwise_tanh); |
1202 | kernel_ctx.define_int("ELU" , alg_kind::eltwise_elu); |
1203 | kernel_ctx.define_int("SQUARE" , alg_kind::eltwise_square); |
1204 | kernel_ctx.define_int("SQRT" , alg_kind::eltwise_sqrt); |
1205 | kernel_ctx.define_int("ABS" , alg_kind::eltwise_abs); |
1206 | kernel_ctx.define_int("EXP" , alg_kind::eltwise_exp); |
1207 | kernel_ctx.define_int("GELU_TANH" , alg_kind::eltwise_gelu_tanh); |
1208 | kernel_ctx.define_int("SWISH" , alg_kind::eltwise_swish); |
1209 | kernel_ctx.define_int("LOG" , alg_kind::eltwise_log); |
1210 | kernel_ctx.define_int("CLIP" , alg_kind::eltwise_clip); |
1211 | kernel_ctx.define_int("CLIP_V2" , alg_kind::eltwise_clip_v2); |
1212 | kernel_ctx.define_int("POW" , alg_kind::eltwise_pow); |
1213 | kernel_ctx.define_int("GELU_ERF" , alg_kind::eltwise_gelu_erf); |
1214 | kernel_ctx.define_int("ROUND" , alg_kind::eltwise_round); |
1215 | kernel_ctx.define_int("HARDSWISH" , alg_kind::eltwise_hardswish); |
1216 | kernel_ctx.define_int("HARDSIGMOID" , alg_kind::eltwise_hardsigmoid); |
1217 | |
1218 | kernel_ctx.define_int("RELU_DST" , alg_kind::eltwise_relu_use_dst_for_bwd); |
1219 | kernel_ctx.define_int( |
1220 | "LOGISTIC_DST" , alg_kind::eltwise_logistic_use_dst_for_bwd); |
1221 | kernel_ctx.define_int("TANH_DST" , alg_kind::eltwise_tanh_use_dst_for_bwd); |
1222 | kernel_ctx.define_int("ELU_DST" , alg_kind::eltwise_elu_use_dst_for_bwd); |
1223 | kernel_ctx.define_int("SQRT_DST" , alg_kind::eltwise_sqrt_use_dst_for_bwd); |
1224 | kernel_ctx.define_int("EXP_DST" , alg_kind::eltwise_exp_use_dst_for_bwd); |
1225 | kernel_ctx.define_int( |
1226 | "CLIP_V2_DST" , alg_kind::eltwise_clip_v2_use_dst_for_bwd); |
1227 | } |
1228 | |
1229 | inline bool post_ops_with_binary_ok(const primitive_attr_t *attr, |
1230 | const data_type_t dst_dt, const int max_ndims_supported = 2, |
1231 | const int prelu_mask_supported = 3) { |
1232 | const auto &p = attr->post_ops_; |
1233 | |
1234 | auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(false); }; |
1235 | auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); }; |
1236 | auto is_binary = [&](int idx) { return p.entry_[idx].is_binary(); }; |
1237 | auto is_prelu = [&](int idx) { return p.entry_[idx].is_prelu(); }; |
1238 | |
1239 | bool is_po_ok = true; |
1240 | for (int po_idx = 0; po_idx < p.len(); ++po_idx) { |
1241 | is_po_ok = is_po_ok |
1242 | && (is_eltwise(po_idx) || is_sum(po_idx) || is_binary(po_idx) |
1243 | || is_prelu(po_idx)); |
1244 | if (is_binary(po_idx)) { |
1245 | const auto &bin_desc = p.entry_[po_idx].binary.src1_desc; |
1246 | if (bin_desc.ndims > max_ndims_supported) { |
1247 | // accept descriptor if unsupported dims are equal to 1. |
1248 | for (int dim_idx = max_ndims_supported; |
1249 | dim_idx < bin_desc.ndims; ++dim_idx) { |
1250 | if (bin_desc.dims[dim_idx] != 1) is_po_ok = false; |
1251 | } |
1252 | } |
1253 | } |
1254 | if (is_prelu(po_idx)) { |
1255 | if (p.entry_[po_idx].prelu.mask > prelu_mask_supported) |
1256 | is_po_ok = false; |
1257 | } |
1258 | if (is_sum(po_idx)) { |
1259 | if (p.entry_[po_idx].sum.zero_point != 0) return false; |
1260 | if (p.entry_[po_idx].sum.dt != dnnl_data_type_undef |
1261 | && types::data_type_size(p.entry_[po_idx].sum.dt) |
1262 | != types::data_type_size(dst_dt)) |
1263 | return false; |
1264 | } |
1265 | } |
1266 | |
1267 | if (p.len() > MAX_POST_OPS_SUPPORTED) is_po_ok = false; |
1268 | |
1269 | return is_po_ok; |
1270 | } |
1271 | |
// Compiles the post-op chain into kernel macros. For every post-op index i
// it defines PO_<i>_KIND, PO_<i>_ALG, a PO_<i>_BIN_ARG memory-descriptor
// macro set, eltwise alpha/beta/scale and sum scale (with neutral defaults
// when the entry is not of that kind), and accumulates the extra kernel
// parameters into the POST_OP_ARGS string macro.
// NOTE(review): `dst_dims` may only be nullptr when no prelu post-op is
// present (asserted below).
inline void def_post_ops_cfg(compute::kernel_ctx_t &kernel_ctx,
        const post_ops_t &post_ops, const dnnl_dims_t *dst_dims) {
    // Kind ids shared with the OpenCL post-op service code.
    const int po_nop_id = 0;
    const int po_binary_id = 1;
    const int po_eltwise_id = 2;
    const int po_sum_id = 3;

    kernel_ctx.define_int("PO_BINARY", po_binary_id);
    kernel_ctx.define_int("PO_ELTWISE", po_eltwise_id);
    kernel_ctx.define_int("PO_SUM", po_sum_id);

    // Extra kernel arguments (one buffer per post-op), built up below and
    // passed to the compiler as a single quoted macro.
    std::string po_kernel_args = "-DPOST_OP_ARGS=\"";
    int nof_supported_post_ops = 0;

    auto add_po_defines = [&](const std::string &bin_arg_name,
                                  const post_ops_t::entry_t &e, int idx) {
        if (e.is_binary()) {
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_binary_id);
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_ALG", e.binary.alg);

            const memory_desc_wrapper src1_mdw(e.binary.src1_desc);
            const auto mdi = memory_desc_info_t::create(src1_mdw);
            def_memory_desc_info(kernel_ctx, mdi, bin_arg_name.c_str());
            // bf16 binary arguments need explicit conversion in the kernel.
            if (mdi.data_type == data_type::bf16) {
                kernel_ctx.define_int(
                        "PO_" + std::to_string(idx) + "_BIN_ARG_DT_IS_BF16", 1);
            } else {
                kernel_ctx.define_int(
                        "PO_" + std::to_string(idx) + "_BIN_ARG_DT_IS_BF16", 0);
            }
        } else if (e.is_prelu()) {
            // binary && eltwise relu = prelu post op
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_binary_id);
            kernel_ctx.define_int("PO_" + std::to_string(idx) + "_ALG",
                    alg_kind_t::dnnl_eltwise_relu);

            assert(dst_dims != nullptr);

            // Synthesize a weights descriptor from the prelu mask: each set
            // mask bit d takes its extent from dst_dims[d], others are 1.
            memory_desc_t weight_mem_desc;
            dims_t weight_dims {};
            format_tag_t weights_tag;
            int weight_ndims = 0;
            if (e.prelu.mask == 0) {
                weight_ndims = 1;
                weight_dims[0] = 1;
                weights_tag = format_tag_t::dnnl_a;
            } else {
                // prelu weights are assumed to be upto 5 dims
                for (int d = 0; d < 5; ++d) {
                    if (((e.prelu.mask >> d) & 0x1) == 1) {
                        weight_ndims = d + 1;
                        weight_dims[d] = (*dst_dims)[d];
                    } else {
                        weight_dims[d] = 1;
                    }
                }
                switch (weight_ndims) {
                    case 1: weights_tag = format_tag_t::dnnl_a; break;
                    case 2: weights_tag = format_tag_t::dnnl_ab; break;
                    case 3: weights_tag = format_tag_t::dnnl_acb; break;
                    case 4: weights_tag = format_tag_t::dnnl_acdb; break;
                    case 5: weights_tag = format_tag_t::dnnl_acdeb; break;
                    default:
                        weights_tag = format_tag_t::dnnl_format_tag_undef;
                        break;
                }
            }
            memory_desc_init_by_tag(weight_mem_desc, weight_ndims, weight_dims,
                    data_type_t::dnnl_f32, weights_tag);
            const memory_desc_wrapper weight_mdw(weight_mem_desc);
            const auto mdi = memory_desc_info_t::create(weight_mdw);
            def_memory_desc_info(kernel_ctx, mdi, bin_arg_name.c_str());

            // prelu weights are assumed to be f32
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_BIN_ARG_DT_IS_BF16", 0);
        } else {
            // Non-binary post-op: emit a dummy descriptor so the
            // PO_<idx>_BIN_ARG macros are still defined.
            memory_desc_t empty_mem_desc;
            dnnl_dims_t empty_dims = {1, 1, 1, 1};
            memory_desc_init_by_tag(empty_mem_desc, 4, empty_dims,
                    data_type_t::dnnl_s8, format_tag_t::dnnl_nchw);
            const memory_desc_wrapper src1_mdw(empty_mem_desc);
            const auto mdi = memory_desc_info_t::create(src1_mdw);
            def_memory_desc_info(kernel_ctx, mdi, bin_arg_name.c_str());
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_BIN_ARG_DT_IS_BF16", 0);
        }
        if (e.is_eltwise(false)) {
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_eltwise_id);
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_ALG", e.eltwise.alg);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_ALPHA").c_str(),
                    e.eltwise.alpha);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_BETA").c_str(),
                    e.eltwise.beta);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_SCALE").c_str(),
                    e.eltwise.scale);
        } else {
            // Neutral eltwise parameters (identity) for non-eltwise entries.
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_ALPHA").c_str(),
                    1.0f);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_BETA").c_str(),
                    0.0f);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_SCALE").c_str(),
                    1.0f);
        }
        if (e.is_sum(false)) {
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_sum_id);
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_ALG", alg_kind::undef);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(),
                    e.sum.scale);
        } else {
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(), 1.0f);
        }
        if (!(e.is_binary() || e.is_eltwise(false) || e.is_sum(false)
                    || e.is_prelu())) {
            // empty post op
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_nop_id);
            // *_ALG need to be set but it's unused when kind is NOP
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_ALG", alg_kind::undef);
            // Cancels the ++ applied by the loop below for this entry.
            --nof_supported_post_ops;
        }
        po_kernel_args += ", const __global PO_" + std::to_string(idx)
                + "_BIN_ARG_DATA_T *po_" + std::to_string(idx) + "_binary_arg";
    };

    for (int idx = 0; idx < post_ops.len(); ++idx, ++nof_supported_post_ops) {
        const std::string bin_arg_name
                = "PO_" + std::to_string(idx) + "_BIN_ARG";
        add_po_defines(bin_arg_name, post_ops.entry_[idx], idx);
    }

    kernel_ctx.define_int("POST_OP_CHAIN_LENGTH", nof_supported_post_ops);
    if (post_ops.len() > 0) {
        // due to C macro limitations on which post op service is build always
        // load bf16 conversion functions
        kernel_ctx.define_int("POST_OP_USING_BF16", 1);
    }
    po_kernel_args += "\"";
    kernel_ctx.add_option(po_kernel_args);
}
1428 | |
1429 | inline int append_post_ops_to_arg_list_base(const exec_args_t &args, |
1430 | compute::kernel_arg_list_t &arg_list, int post_op_idx, |
1431 | const post_ops_t &post_ops) { |
1432 | auto set_arg_entry = [&](const post_ops_t::entry_t &e, int po_idx) { |
1433 | if (e.is_binary()) { |
1434 | auto arg = args.at( |
1435 | DNNL_ARG_ATTR_MULTIPLE_POST_OP(po_idx) | DNNL_ARG_SRC_1); |
1436 | assert(arg.is_const); |
1437 | |
1438 | auto &binary_arg = arg.mem |
1439 | ? *(arg.mem->memory_storage()) |
1440 | : dnnl::impl::memory_storage_t::empty_storage(); |
1441 | arg_list.set(post_op_idx++, binary_arg); |
1442 | } else if (e.is_prelu()) { |
1443 | auto arg = args.at( |
1444 | DNNL_ARG_ATTR_MULTIPLE_POST_OP(po_idx) | DNNL_ARG_WEIGHTS); |
1445 | assert(arg.is_const); |
1446 | auto &prelu_wei_arg = arg.mem |
1447 | ? *(arg.mem->memory_storage()) |
1448 | : dnnl::impl::memory_storage_t::empty_storage(); |
1449 | arg_list.set(post_op_idx++, prelu_wei_arg); |
1450 | } else { |
1451 | arg_list.set(post_op_idx++, memory_storage_t::empty_storage()); |
1452 | } |
1453 | }; |
1454 | |
1455 | for (int idx = 0; idx < post_ops.len(); ++idx) { |
1456 | set_arg_entry(post_ops.entry_[idx], idx); |
1457 | } |
1458 | return post_op_idx; |
1459 | } |
// Variant taking pre-extracted exec args (forwards directly to the base
// helper); used by implementations that assemble their own exec_args_t.
inline int append_post_ops_to_arg_list_gemm(const exec_args_t &args,
        compute::kernel_arg_list_t &arg_list, int post_op_idx,
        const post_ops_t &post_ops) {
    return append_post_ops_to_arg_list_base(
            args, arg_list, post_op_idx, post_ops);
}
1466 | inline int append_post_ops_to_arg_list(const exec_ctx_t &ctx, |
1467 | compute::kernel_arg_list_t &arg_list, int post_op_idx, |
1468 | const post_ops_t &post_ops) { |
1469 | exec_args_t args; |
1470 | return append_post_ops_to_arg_list_base( |
1471 | ctx.args(), arg_list, post_op_idx, post_ops); |
1472 | } |
1473 | |
1474 | inline bool post_ops_preserves_zeroes( |
1475 | const exec_ctx_t &ctx, const post_ops_t &post_ops) { |
1476 | bool preserve_zeroes = true; |
1477 | for (int idx = 0; idx < post_ops.len(); ++idx) { |
1478 | const post_ops_t::entry_t &po_entry = post_ops.entry_[idx]; |
1479 | if (po_entry.is_binary()) { |
1480 | // only binary mul is preserving zeroes |
1481 | preserve_zeroes &= po_entry.binary.alg |
1482 | == dnnl::impl::alg_kind_t::dnnl_binary_mul; |
1483 | } |
1484 | if (po_entry.is_eltwise(false)) { |
1485 | preserve_zeroes &= gpu_eltwise_fwd_pd_t::eltwise_preserves_zero( |
1486 | po_entry.eltwise.alg, po_entry.eltwise.alpha, |
1487 | po_entry.eltwise.beta); |
1488 | } |
1489 | } |
1490 | return preserve_zeroes; |
1491 | } |
1492 | |
// Emits all attribute-related kernel macros: legacy single-slot eltwise/sum
// info, scale and zero-point flags, the alg-kind tables, and the full
// per-post-op configuration (via def_post_ops_cfg).
// NOTE(review): dst_dims is only required when a prelu post-op is present.
inline void def_attr_info(compute::kernel_ctx_t &kernel_ctx,
        const attr_info_t &attr_info, const post_ops_t &post_ops,
        const dnnl_dims_t *dst_dims = nullptr) {
    assert(attr_info.initialized);

    kernel_ctx.define_int("WITH_POST_OP", post_ops.len() > 0);

    kernel_ctx.define_int("WITH_ELTWISE", attr_info.with_eltwise);
    kernel_ctx.define_int("ELTWISE_IDX", attr_info.eltwise_idx);
    kernel_ctx.define_int("ELTWISE_ALG", attr_info.eltwise_alg);

    kernel_ctx.define_int("WITH_SUM", attr_info.with_sum);
    kernel_ctx.define_int("SUM_IDX", attr_info.sum_idx);
    kernel_ctx.define_int("SUM_SCALE", attr_info.sum_scale);
    kernel_ctx.define_int("SUM_SCALE1", attr_info.sum_scale == 1.0f);

    kernel_ctx.define_int("WITH_SRC0_SCALE", attr_info.with_src0_scale);
    kernel_ctx.define_int("WITH_SRC1_SCALE", attr_info.with_src1_scale);

    // Output scales.
    kernel_ctx.define_int("WITH_SCALES", attr_info.with_oscales);
    kernel_ctx.define_int(
            "WITH_RUNTIME_SCALES", attr_info.with_runtime_oscales);
    kernel_ctx.define_int("SCALES_PER_OC", attr_info.with_per_oc_oscales);
    kernel_ctx.define_int("SCALES_COMMON", attr_info.with_common_oscales);

    // Per-argument scales.
    kernel_ctx.define_int("WITH_SRC_SCALES", attr_info.with_src_scales);
    kernel_ctx.define_int("WITH_WEI_SCALES", attr_info.with_wei_scales);
    kernel_ctx.define_int("WITH_DST_SCALES", attr_info.with_dst_scales);
    kernel_ctx.define_int("WEI_SCALES_MASK", attr_info.wei_scales_mask);

    // Zero points.
    kernel_ctx.define_int("WITH_SRC_ZPOINTS", attr_info.with_src_zpoints);
    kernel_ctx.define_int("WITH_WEI_ZPOINTS", attr_info.with_wei_zpoints);
    kernel_ctx.define_int("WITH_DST_ZPOINTS", attr_info.with_dst_zpoints);
    kernel_ctx.define_int(
            "WITH_SRC_ZPOINTS_PER_IC", attr_info.with_per_ic_src_zpoints);
    kernel_ctx.define_int(
            "WITH_DST_ZPOINTS_PER_OC", attr_info.with_per_oc_dst_zpoints);

    def_binary_alg_kinds(kernel_ctx);
    def_eltwise_alg_kinds(kernel_ctx);

    def_post_ops_cfg(kernel_ctx, post_ops, dst_dims);
}
1536 | |
// Lets the dispatcher define its kernel macros (work-group/dispatch
// related) into the compilation context.
inline void def_dispatch(compute::kernel_ctx_t &kernel_ctx,
        const compute::dispatch_t &dispatch) {
    dispatch.def_kernel_macros(kernel_ctx);
}
1541 | |
1542 | inline void maybe_fix_non_uniform_work_sizes( |
1543 | bool has_non_uniform_wg, conv_conf_t &conf) { |
1544 | for (int i = 0; i < 3; i++) { |
1545 | conf.gws_orig_d[i] = conf.gws_d[i]; |
1546 | if (!has_non_uniform_wg) |
1547 | conf.gws_d[i] = utils::rnd_up(conf.gws_d[i], conf.lws_d[i]); |
1548 | } |
1549 | } |
1550 | |
// Chooses spatial/channel/minibatch blocking for convolution
// backward-by-weights. Depthwise gets a fixed trivial blocking; otherwise
// the spatial tile (odb/ohb/owb) and minibatch split are grown greedily
// while keeping the working set within the LLC and the thread count high
// enough to occupy the device.
inline void bwd_w_compute_block_sizes(conv_conf_t &conf, engine_t *engine) {
    // Heuristic: 3 input channels is treated as a first convolution.
    const bool is_1stconv = conf.ic_without_padding == 3;

    if (conf.is_depthwise) {
        // Depthwise: no channel blocking, whole padded width as one tile.
        conf.odb = 1;
        conf.ohb = 1;
        conf.owb = utils::rnd_up(conf.ow, conf.ow_block);
        conf.ocb = 1;
        conf.icb = 1;
        conf.osp_chunk = utils::div_up(conf.od, conf.odb)
                * utils::div_up(conf.oh, conf.ohb)
                * utils::div_up(conf.ow, conf.owb);

        conf.mb_chunk = utils::div_up(conf.mb, conf.mb_block);
        conf.nchunk = conf.osp_chunk * conf.mb_chunk;
        return;
    }
    auto *dev_info = utils::downcast<compute::compute_engine_t *>(engine)
                             ->device_info();
    int hw_threads = dev_info->hw_threads();
    size_t llc_bytes = dev_info->llc_cache_size();

    // Next block-size candidate: grow by 1 when the size is large enough to
    // amortize a remainder, otherwise jump to the next exact divisor.
    auto next_candidate = [](int size, int block) {
        if (size == block) return block;
        // If size is big enough, then do not care about the remainder.
        if (block * 16 < size) return block + 1;
        // Otherwise search for the next divisor.
        block++;
        while (size % block != 0)
            block++;
        return block;
    };

    int mb_nb = 1;
    conf.odb = 1;
    conf.ohb = 1;
    conf.owb = 1;

    int mb_nblk = utils::div_up(conf.mb, conf.mb_block);
    int ic_nblk = utils::div_up(conf.ic, conf.ic_block);
    int oc_nblk = utils::div_up(conf.oc, conf.oc_block);

    // Channel blocking: largest divisor not exceeding 16 blocks (first
    // convolution keeps ic unblocked).
    int ic_nb_max = is_1stconv ? 1 : nstl::min(ic_nblk, 16);
    int oc_nb_max = nstl::min(oc_nblk, 16);
    int ic_nb = is_1stconv ? 1 : utils::max_div(ic_nblk, ic_nb_max);
    int oc_nb = utils::max_div(oc_nblk, oc_nb_max);

    int mb_nb_max = 1;
    if (!is_1stconv && (conf.mb_block == 1) && (conf.ic % 1024 != 0)
            && (conf.oc % 1024 != 0)) {
        mb_nb_max = 4;
    }

    // Total number of work items spawned for the current blocking.
    auto get_nthr = [&]() {
        int nthr = utils::div_up(mb_nblk, mb_nb)
                * utils::div_up(conf.od, conf.odb)
                * utils::div_up(conf.oh, conf.ohb)
                * utils::div_up(conf.ow, conf.owb) * conf.kh * conf.kw * conf.kd
                * oc_nblk * (is_1stconv ? 1 : ic_nblk) * conf.ngroups;
        return nthr;
    };

    // Estimated src+dst bytes touched by the inner loops for the current
    // blocking (scaled up when there aren't enough threads per tile).
    auto get_src_dst_size = [&]() {
        int iwb = conf.ndims < 3 ? 1 : conf.owb + 2 * (conf.kw - 1);
        int ihb = conf.ndims < 4 ? 1 : conf.ohb + 2 * (conf.kh - 1);
        int idb = conf.ndims < 5 ? 1 : conf.odb + 2 * (conf.kd - 1);

        size_t ispb = iwb * ihb * idb;
        size_t ospb = conf.owb * conf.ohb * conf.odb;
        size_t src_size = sizeof(float) * conf.mb_block
                * (is_1stconv ? conf.ic : ic_nb * conf.ic_block) * ispb;
        size_t dst_size = sizeof(float) * conf.mb_block
                * (oc_nb * conf.oc_block) * ospb;

        int nthr_per_spb
                = conf.kh * conf.kw * conf.kd * ic_nb * oc_nb * conf.ngroups;
        size_t sz = (size_t)(src_size + dst_size);
        if (nthr_per_spb < hw_threads) sz = sz * hw_threads / nthr_per_spb;
        return sz;
    };

    // Accept the candidate only if it keeps the heuristics satisfied;
    // roll back otherwise.
    auto try_next = [&](int &v, int next) {
        if (next <= v) return false;
        int v_old = v;
        v = next;
        // Heuristics:
        // - src and dst size accessed in the inner loops should fit LLC
        // - Require at least (3 * hw_threads) to spawn to have enough
        //   parallelism
        if (get_src_dst_size() > llc_bytes || get_nthr() < 3 * hw_threads) {
            v = v_old;
            return false;
        }
        return true;
    };

    // These kernel versions require owb to stay a multiple of ow_block.
    if (utils::one_of(conf.ver, ver_nhwc, ver_8ow16c, ver_1stconv))
        conf.owb = conf.ow_block;

    // Increase spatial tile size as much as possible.
    for (int i = 0; i < 128; i++) {
        int owb_next;
        if (utils::one_of(conf.ver, ver_nhwc, ver_8ow16c, ver_1stconv)) {
            int ow_padded = utils::rnd_up(conf.ow, conf.ow_block);
            owb_next = conf.ow_block
                    * next_candidate(ow_padded / conf.ow_block,
                            conf.owb / conf.ow_block);
        } else {
            owb_next = next_candidate(conf.ow, conf.owb);
        }
        try_next(conf.owb, owb_next);

        int ohb_next = next_candidate(conf.oh, conf.ohb);
        try_next(conf.ohb, ohb_next);

        int odb_next = next_candidate(conf.od, conf.odb);
        try_next(conf.odb, odb_next);

        int mb_nb_next = next_candidate(mb_nb_max, mb_nb);
        try_next(mb_nb, mb_nb_next);
    }

    conf.icb = is_1stconv ? conf.ic : ic_nb * conf.ic_block;
    conf.ocb = oc_nb * conf.oc_block;

    conf.osp_chunk = utils::div_up(conf.od, conf.odb)
            * utils::div_up(conf.oh, conf.ohb)
            * utils::div_up(conf.ow, conf.owb);

    conf.mb_chunk = utils::div_up(mb_nblk, mb_nb);

    conf.nchunk = conf.mb_chunk * conf.osp_chunk;
}
1684 | |
1685 | } // namespace gpu |
1686 | } // namespace impl |
1687 | } // namespace dnnl |
1688 | |
1689 | #endif |
1690 | |