/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_PRIMITIVE_CONF_HPP
#define GPU_PRIMITIVE_CONF_HPP

#include <stdint.h>

#include <string>
#include <vector>

#include "common/c_types_map.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/memory_storage.hpp"
#include "common/primitive_attr.hpp"
#include "common/primitive_exec_types.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_eltwise_pd.hpp"

namespace dnnl {
namespace impl {
namespace gpu {

#define MAX_NDIMS 6
#define MAX_POST_OPS_SUPPORTED 32

// Returns true when a memory descriptor has more dimensions than the GPU
// kernels support (MAX_NDIMS), i.e. a "true" result means "not ok".
inline bool memory_desc_ndims_ok(const memory_desc_t *md) {
    return md->ndims > MAX_NDIMS;
}

template <typename T, typename... Rest>
bool memory_desc_ndims_ok(const T *first, const Rest *... rest) {
    return memory_desc_ndims_ok(first) || memory_desc_ndims_ok(rest...);
}
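
// Usage sketch for memory_desc_ndims_ok (hypothetical caller): the variadic
// overload lets a primitive reject all of its tensors in one check, e.g.
//
//     if (memory_desc_ndims_ok(src_md(), weights_md(), dst_md()))
//         return status::unimplemented;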

inline dim_t get_attr_oscales_count(int mask, const memory_desc_wrapper &md) {
    dim_t count = 1;
    if (mask == 0) return count;

    for (int d = 0; d < md.ndims(); d++) {
        const int dim_mask = 1 << d;
        if (dim_mask & mask) count *= md.dims()[d];
    }

    return count;
}
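
// Worked example for get_attr_oscales_count (illustrative dims): for a
// weights tensor of dims {G, OC, IC, KH, KW} and mask == (1 << 0) | (1 << 1),
// the count is G * OC -- one scale per (group, output channel) pair. A mask
// of 0 always yields a single common scale.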

struct memory_desc_info_t {
    // Max levels of blocking
    static const int max_nlevels = 3;

    int ndims;
    data_type_t data_type;

    int offset0;
    int dims[MAX_NDIMS];
    int padded_dims[MAX_NDIMS];

    int nlevels;
    int blocks[MAX_NDIMS][max_nlevels + 1];
    int strides[MAX_NDIMS][max_nlevels + 1];

    static memory_desc_info_t create(const memory_desc_wrapper &mdw) {
        using namespace format_tag;

        auto md_info = memory_desc_info_t();

        md_info.nlevels = 2;

        md_info.ndims = mdw.ndims();
        md_info.data_type = mdw.data_type();
        md_info.offset0 = mdw.offset0();

        auto &blk = mdw.blocking_desc();
        dim_t blk_stride
                = utils::array_product(blk.inner_blks, blk.inner_nblks);

        for (int d = 0; d < mdw.ndims(); ++d) {
            utils::array_set(md_info.blocks[d], 1, md_info.nlevels + 1);
            utils::array_set(md_info.strides[d], 0, md_info.nlevels + 1);
        }

        for (int d = 0; d < mdw.ndims(); ++d) {
            md_info.dims[d] = mdw.dims()[d];
            md_info.padded_dims[d] = mdw.padded_dims()[d];
            md_info.strides[d][0] = blk.strides[d];
        }

        int levels[MAX_NDIMS] = {0};
        for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
            int d = blk.inner_idxs[iblk];
            ++levels[d];

            md_info.blocks[d][levels[d]] = blk.inner_blks[iblk];
            blk_stride /= blk.inner_blks[iblk];
            md_info.strides[d][levels[d]] = blk_stride;
        }
        return md_info;
    }
};
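
// Illustration of memory_desc_info_t::create (assuming a 4D tensor in the
// standard blocked layout ABcd16b, i.e. nChw16c): the single inner block on
// dim 1 is recorded as one extra level,
//
//     blocks[1]  = {1, 16, 1};              // level 1 holds the 16-wide block
//     strides[1] = {blk.strides[1], 1, 0};  // innermost block has stride 1
//
// so a kernel can reconstruct the physical offset of a logical index by
// combining blocks and strides level by level.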

struct attr_info_t {
    static attr_info_t create(const primitive_attr_t *attr) {
        const auto &po = attr->post_ops_;

        attr_info_t attr_info;

        // Binary
        attr_info.binary_idx = po.find(primitive_kind::binary);
        attr_info.with_binary = (attr_info.binary_idx != -1);

        // Eltwise
        attr_info.eltwise_idx = po.find(primitive_kind::eltwise);
        attr_info.with_eltwise = (attr_info.eltwise_idx != -1);

        if (attr_info.with_eltwise) {
            auto &eltwise = po.entry_[attr_info.eltwise_idx].eltwise;
            attr_info.eltwise_alg = eltwise.alg;
            attr_info.eltwise_scale = eltwise.scale;
            attr_info.eltwise_alpha = eltwise.alpha;
            attr_info.eltwise_beta = eltwise.beta;
        } else {
            attr_info.eltwise_alg = alg_kind::undef;
            attr_info.eltwise_scale = 1.0f;
            attr_info.eltwise_alpha = 1.0f;
            attr_info.eltwise_beta = 0.0f;
        }

        // Sum
        attr_info.sum_idx = po.find(primitive_kind::sum);
        attr_info.sum_scale = (attr_info.sum_idx != -1
                        ? po.entry_[attr_info.sum_idx].sum.scale
                        : 0.0f);
        attr_info.sum_data_type = (attr_info.sum_idx != -1)
                ? po.entry_[attr_info.sum_idx].sum.dt
                : dnnl_data_type_undef;
        attr_info.with_sum
                = (attr_info.sum_idx != -1) && (attr_info.sum_scale != 0.0f);

        // Output scales
        attr_info.with_oscales
                = !attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values();

        const auto &scales_mask = attr->scales_.get(DNNL_ARG_WEIGHTS).mask_;
        attr_info.with_common_oscales
                = attr_info.with_oscales && (scales_mask == 0);
        attr_info.with_per_oc_oscales
                = attr_info.with_oscales && (scales_mask == (1 << 1));

        attr_info.with_runtime_oscales = !attr->output_scales_.defined();

        const auto &src0_scales = attr->scales_.get(DNNL_ARG_SRC_0);
        attr_info.with_src0_scale = !src0_scales.has_default_values();
        assert(src0_scales.mask_ == 0);

        const auto &src1_scales = attr->scales_.get(DNNL_ARG_SRC_1);
        attr_info.with_src1_scale = !src1_scales.has_default_values();
        assert(src1_scales.mask_ == 0);

        const auto &src_scales = attr->scales_.get(DNNL_ARG_SRC);
        attr_info.with_src_scales = !src_scales.has_default_values();
        assert(src_scales.mask_ == 0);

        const auto &wei_scales = attr->scales_.get(DNNL_ARG_WEIGHTS);
        attr_info.with_wei_scales = !wei_scales.has_default_values();
        attr_info.wei_scales_mask = wei_scales.mask_;

        const auto &dst_scales = attr->scales_.get(DNNL_ARG_DST);
        attr_info.with_dst_scales = !dst_scales.has_default_values();
        assert(dst_scales.mask_ == 0);

        // Zero points
        const auto &zp = attr->zero_points_;
        attr_info.with_src_zpoints = !zp.has_default_values(DNNL_ARG_SRC);
        attr_info.with_wei_zpoints = !zp.has_default_values(DNNL_ARG_WEIGHTS);
        attr_info.with_dst_zpoints = !zp.has_default_values(DNNL_ARG_DST);

        attr_info.with_per_ic_src_zpoints = attr_info.with_src_zpoints
                && !zp.defined(DNNL_ARG_SRC) && !zp.common(DNNL_ARG_SRC);

        attr_info.with_per_oc_dst_zpoints = attr_info.with_dst_zpoints
                && !zp.defined(DNNL_ARG_DST) && !zp.common(DNNL_ARG_DST);

        attr_info.initialized = true;
        return attr_info;
    }

    bool initialized = false;

    bool with_binary;
    bool with_eltwise;
    int eltwise_idx;
    int binary_idx;
    alg_kind_t eltwise_alg;
    float eltwise_scale;
    float eltwise_alpha;
    float eltwise_beta;

    bool with_sum;
    int sum_idx;
    float sum_scale;
    data_type_t sum_data_type;

    bool with_oscales;
    bool with_common_oscales;
    bool with_per_oc_oscales;
    bool with_runtime_oscales;

    bool with_src0_scale;
    bool with_src1_scale;
    bool with_src_scales;
    bool with_wei_scales;
    bool with_dst_scales;
    // Full scales mask, not a flag: multiple bits may be set, so storing it
    // in a bool would truncate it.
    int wei_scales_mask;

    bool with_src_zpoints;
    bool with_wei_zpoints;
    bool with_dst_zpoints;
    bool with_per_ic_src_zpoints;
    bool with_per_oc_dst_zpoints;
};
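
// Typical use of attr_info_t (sketch, names illustrative): a primitive
// descriptor fills it once at init time, and the kernel-generation code later
// lowers it to OpenCL macros:
//
//     conf.attr_info = attr_info_t::create(attr());
//     // ... at kernel-context creation:
//     def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_);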

struct offsets_t {
    int src_off[4][MAX_NDIMS];
    int wei_off[4][MAX_NDIMS];
    int dst_off[4][MAX_NDIMS];
};

struct rnn_offsets_t {
    int src_layer_off[4][MAX_NDIMS];
    int src_iter_off[4][MAX_NDIMS];
    int src_iter_c_off[4][MAX_NDIMS];
    int weights_layer_off[4][MAX_NDIMS];
    int weights_iter_off[4][MAX_NDIMS];
    int bias_off[4][MAX_NDIMS];
    int dst_layer_off[4][MAX_NDIMS];
    int dst_iter_off[4][MAX_NDIMS];
    int dst_iter_c_off[4][MAX_NDIMS];
    int diff_src_layer_off[4][MAX_NDIMS];
    int diff_src_iter_off[4][MAX_NDIMS];
    int diff_src_iter_c_off[4][MAX_NDIMS];
    int diff_weights_layer_off[4][MAX_NDIMS];
    int diff_weights_iter_off[4][MAX_NDIMS];
    int diff_bias_off[4][MAX_NDIMS];
    int diff_dst_layer_off[4][MAX_NDIMS];
    int diff_dst_iter_off[4][MAX_NDIMS];
    int diff_dst_iter_c_off[4][MAX_NDIMS];
    int ws_off[4][MAX_NDIMS];
};

// Convolution
enum conv_version_t {
    ver_unused,
    ver_1stconv,
    ver_16mb16c,
    ver_32mb16c,
    ver_32mb32c,
    ver_32c,
    ver_8ow16c,
    ver_nhwc,
    ver_nchw,
    ver_mb_block,
    ver_ow_block,

    // Xe_HP-specific versions.
    ver_v1,
    ver_v2
};

struct conv_conf_t {
    prop_kind_t prop_kind;

    int ndims;
    int mb;
    int ngroups, ic, oc;
    int ngroups_without_padding, oc_without_padding, ic_without_padding;
    int id, ih, iw, od, oh, ow;
    int f_pad, l_pad, t_pad;
    int back_pad, r_pad, b_pad;
    int kd, kh, kw, kwb;
    int stride_d, stride_h, stride_w;
    int dilate_d, dilate_h, dilate_w;

    int sp_block, sp;
    int od_block, oh_block, ow_block;
    int id_block, ih_block, iw_block;
    int oc_block, ic_block, nchunk;
    int omb;
    int odb, ohb, owb;
    int icb;
    int ocb;
    int osp_chunk, mb_chunk, mb_block;
    int iw_tail;
    size_t wei_slm_size, src_slm_size, dst_slm_size;
    int sub_group_size;

    size_t gws_d[3], lws_d[3];
    // Original global work sizes, kept from before gws_d is rounded up on
    // devices that do not support non-uniform work-groups.
    size_t gws_orig_d[3];
    compute::dispatch_t dispatch;

    bool with_bias, with_groups;

    attr_info_t attr_info;

    bool is_depthwise;
    bool is_nhwc;
    bool reorder_wei = false;
    bool reorder_bias = false;
    int ver;
    format_tag_t src_tag, dst_tag, wei_tag;
    bool is_nchw;
    bool is_src_nchw, is_src_nhwc;
    bool is_dst_nhwc;

    // Winograd-specific parameters.
    int tile_size;
    int wino_m;
    int wino_r;
    int wino_ih, wino_oh;
    int wino_iw, wino_ow;
    int wino_ic;
    int wino_oc;
    int wino_ic_block;
    int wino_oc_block;
    int vect_size;
    size_t U_gws_d[3], U_lws_d[3];
    size_t V_gws_d[3], V_lws_d[3];
    size_t M_gws_d[3], M_lws_d[3];
    bool is_fused;

    data_type_t src_data_type;
    data_type_t weights_data_type;
    data_type_t bias_data_type;
    data_type_t dst_data_type;
    data_type_t acc_data_type;

    memory_desc_info_t src_md_info;
    memory_desc_info_t wei_md_info;
    memory_desc_info_t dst_md_info;
};

// Pooling
struct pool_conf_t {
    int ndims;
    int mb, c;
    int mb_padded;
    int c_padded;
    int id, ih, iw, od, oh, ow;
    int stride_d, stride_h, stride_w;
    int kd, kh, kw;
    int dd, dh, dw;
    int f_pad, t_pad, l_pad;
    data_type_t src_dt;
    data_type_t dst_dt;
    alg_kind_t alg;
    bool is_plain;
    bool is_training, is_backward;
    bool use_mb_c_block, use_only_c_block;
    bool unroll_mb = false;
    int chunks_per_c_block, chunks_per_mb_block;
    int vect_dt_n;
    int nvect;
    compute::dispatch_t dispatch;
    int sub_group_size;
    int global_pool_spatial_chunk;
    int num_batches = 1;
    int mb_block_size = 16;

    attr_info_t attr_info;
    memory_desc_info_t src_md_info;
    memory_desc_info_t dst_md_info;
};

// Prelu
struct prelu_conf_t {
    bool is_forward;
    bool reduce_diff_weights;
    compute::dispatch_t dispatch;

    attr_info_t attr_info;
    memory_desc_info_t src_md_info;
    memory_desc_info_t wei_md_info;
    memory_desc_info_t dst_md_info;
    memory_desc_info_t diff_src_md_info;
    memory_desc_info_t diff_wei_md_info;
};

// Inner Product
struct inner_product_conf_t {
    int ndims;
    int src_ndims, wei_ndims, dst_ndims;
    int mb, oc, ic, ic_total;
    int id, ih, iw, od, oh, ow;
    int kd, kh, kw;
    bool with_bias, has_spatial;
    bool is_forward, is_backward_data, is_backward_weights;
    compute::dispatch_t dispatch;
    bool reorder_dst = false;

    data_type_t src_dt;
    data_type_t wei_dt;
    data_type_t bia_dt;
    data_type_t dst_dt;
    data_type_t acc_dt;

    attr_info_t attr_info;
};

// RNN
struct rnn_conf_t {
    int cell_kind;
    int activation_kind;
    int direction_kind;
    bool with_bias;
    bool with_src_iter;
    bool with_src_iter_c;
    bool with_dst_iter;
    bool with_dst_iter_c;
    bool is_lbr;
    bool is_vanilla_gru;
    bool is_fwd;
    bool copy_bias;
    bool is_int8;
    bool is_testmode;
    bool is_training;
    data_type_t src_dt;
    data_type_t wei_dt;
    data_type_t bia_dt;
    data_type_t dst_dt;
    data_type_t acc_dt;
    data_type_t aux_dt;
    data_type_t input_dt;
    data_type_t output_dt;
    data_type_t diff_dt;

    int n_layer;
    int n_dir;
    int n_iter;
    int n_iter_scratch_gates;
    int n_gates;
    int n_bias;
    int n_states;
    int n_weights_input;
    int n_weights_state;
    int batch;
    int slc;
    int sic;
    int dhc;
    int dlc;
    int wic;
    int arch_ld;
    int n_parts_weights_iter, n_parts_weights_layer;
    int src_layer_ndims;
    int src_iter_ndims;
    int src_iter_c_ndims;
    int weights_layer_ndims;
    int weights_iter_ndims;
    int dst_layer_ndims;
    int dst_iter_ndims;
    int dst_iter_c_ndims;
    int bias_ndims;
    int diff_src_layer_ndims;
    int diff_src_iter_ndims;
    int diff_src_iter_c_ndims;
    int diff_weights_layer_ndims;
    int diff_weights_iter_ndims;
    int diff_dst_layer_ndims;
    int diff_dst_iter_ndims;
    int diff_dst_iter_c_ndims;
    int diff_bias_ndims;
    int states_ws_ld, gates_ws_ld, scratch_diff_states_ld, scratch_gates_ld;

    int wei_qparam_mask;

    size_t ws_gates_offset;
    size_t ws_states_offset;
    size_t ws_grid_comp_offset;
    size_t ws_h_state_offset;
    size_t ws_c_state_offset;
    size_t ws_bias_offset;
    size_t scratchpad_size;
    size_t scratch_dhG1_offset;
    size_t scratch_gates_offset;
    size_t scratch_cell_offset;
    size_t scratch_diff_states_offset;
    size_t workspace_size;
};

struct rnn_reorder_conf_t {
    bool do_reorder, with_group, has_padding;
    bool with_sum_ab, with_sum_a;
    bool use_ref_impl;
    int ndims;
    size_t nelems;
    compute::dispatch_t dispatch;
    int block[3];
    int sub_group_size;
    int mask;
    size_t scales_count;
};

// Batch Normalization
struct bnorm_conf_t {
    data_type_t data_type;

    int ndims;
    int mb, ic, mb_block, ic_block;
    int reduce_dim_idx, reduce_dim;
    int id, ih, iw;
    int nn, sp, sp_tail, vect_size;
    int stat_sp_nblocks, stat_sp_tail, stat_sp_block;
    int reduce_stat_nblocks;
    bool with_relu, use_16mb_unroll, use_nhwc;
    int stat_ic;
    bool is_forward, is_backward;
    bool use_scale, use_shift, save_stats, is_training;
    bool calculate_stats, calculate_diff_stats;
    bool fuse_norm_relu, fuse_norm_add_relu;
    bool diff_scale, diff_shift;
    float relu_negative_slope, eps;
    int sub_group_size;
    bool vectorize_calc_stats;
    bool skip_reduce_stat;
    bool use_stats_one_pass;
    bool nhwc_optimized;
    int calc_stat_ic;
    bool use_fused_atomics_reduction;

    compute::dispatch_t dispatch_calc_stat;
    compute::dispatch_t dispatch_reduce_stat;
    compute::dispatch_t dispatch;
    compute::dispatch_t dispatch_reduce_aux;
};

// Layer Normalization
struct lnorm_conf_t {
    data_type_t data_type;

    bool is_fwd;
    int ndims;
    int norm_axis;

    memory_desc_info_t src_md_info;
    memory_desc_info_t dst_md_info;
    memory_desc_info_t stat_md_info;

    bool use_scale;
    bool use_shift;
    bool calculate_stats;
    bool save_stats;
    bool vectorize_calc_stats;
    bool vectorize_bwd;
    bool vectorize_bwd_scaleshift;
    float eps;
    int sub_group_size;
    int vect_dt_n;
    int shift_off;
    int n_chunk_size;
    int n_chunks;
    int vector_size_scaleshift;

    compute::dispatch_t dispatch_scaleshift;
    compute::dispatch_t dispatch_scaleshift_finalize;
    compute::dispatch_t dispatch;
};

// Binary
struct binary_conf_t {
    int ndims, nvect;
    bool use_unroll_16b, src0_unroll_16b;
    bool is_plain_layout;
    bool plain_to_ABcd4a4b;
    bool isXa16b;
    data_type_t src0_data_type;
    data_type_t src1_data_type;
    data_type_t dst_data_type;
    bool is_mul;
    bool is_add;
    bool is_max;
    bool is_min;
    bool is_div;
    bool is_sub;
    bool is_ge;
    bool is_gt;
    bool is_le;
    bool is_lt;
    bool is_eq;
    bool is_ne;
    bool is_tensor_op;
    compute::dispatch_t dispatch;
    int mb_block;
    int dim0[MAX_NDIMS];
    int src0_bcast_dims[MAX_NDIMS];
    int src1_bcast_dims[MAX_NDIMS];
    bool is_dense;
    bool is_same_md;
    bool same_src_dt;
    bool with_binary_post_op;

    memory_desc_info_t src0_md_info;
    memory_desc_info_t src1_md_info;
    memory_desc_info_t dst_md_info;

    attr_info_t attr_info;
};

// Reduction
struct reduction_phase_t {
    data_type_t src_type, dst_type;
    compute::nd_range_t nd_range;
    compute::kernel_t kernel;
    int reduction_size, initial_size, final_size;
    int num_reduction_chunks;
    bool is_final, is_first;
};

struct reduction_conf_t {
    // Used by reference implementation
    alg_kind_t alg;
    int ndims, power, div;
    float eps;
    dim_t src_dims[MAX_NDIMS], reduce_dims[MAX_NDIMS], dst_dims[MAX_NDIMS];
    bool is_reduction_dim[MAX_NDIMS];
    data_type_t src_type, dst_type;
    memory_desc_info_t src_md_info, dst_md_info;
    compute::dispatch_t dispatch;
    offsets_t off;
    attr_info_t attr_info;

    // Used by gen9 implementation
    int initial_hwd_dim, initial_hwd_chunk_size;
    int final_hwd_dim, final_hwd_chunk_size;
    int initial_c_chunks, final_c_dim, final_c_chunk_size;
    int initial_n_chunk_size, initial_n_chunks;
    int final_n_dim, final_n_chunk_size;
    bool skip_final_phase;
    int c_block_size, n_block_size;
    int vector_size;
    int sub_group_size;
    compute::dispatch_t finalize_dispatch;

    // Used by combined implementation
    int outer_dim_size, inner_dim_size, gws_inner_dim_size;
    std::vector<reduction_phase_t> phases;
    int inner_dim_per_sg;
    size_t sp_size[2];
};

// Reorder
enum reorder_kernel_t {
    none,
    dense_vector,
    unroll_16b,
    unroll_16b16c,
    unroll_16a16b,
    plain_to_ABcd84a42b,
    vectorize_last_dim,
    plain_to_ABxx8ayb,
    plain_xFxE_to_abcdef,
    transpose8x8,
    transpose16x16,
    local8x8,
    local16x16,
    reorder_nchw,
    unaligned_sizes,
    reorder_alt,
    vectorize_groups,
    pad_innermost,
    xb_to_xab_xba
};

// Resampling
struct resampling_conf_t {
    dim_t ndims;
    offsets_t off;
    dim_t MB, C;
    dim_t ID, IH, IW;
    dim_t OD, OH, OW;
    float FD, FH, FW;
    dim_t vect_size;
    dims_t padded_strides;
    size_t lws[3], gws[3];
    int sub_group_size;
    dim_t padded_c;
    attr_info_t attr_info;
    compute::dispatch_t dispatch;
};

struct block_desc_t {
    int dim_idx;
    int blk_size;
    int step_size;
};

#define LOOP_NEST_LEVEL 4
struct vectorize_last_dim_t {
    int vector_dim;
    int rescale_coeff;
    // composition of data within 16-item packet
    block_desc_t src_vct[LOOP_NEST_LEVEL];
    block_desc_t dst_vct[LOOP_NEST_LEVEL];
    // dimensions to loop over when accessing packets defined above
    block_desc_t src_blk[LOOP_NEST_LEVEL];
    block_desc_t dst_blk[LOOP_NEST_LEVEL];
    int src_blk_limits[MAX_NDIMS];
    int dst_blk_limits[MAX_NDIMS];
    int src_vect_limit;
    int dst_vect_limit;
};

struct vectorize_group_t {
    int vector_dim;
    int src_loop_dim;
    int dst_loop_dim;
    int group_size;
    int innermost_size;
};

struct xb_to_xab_xba_t {
    int vd;
    int blk_size;
    int src_blk_dim;
    int src_blk_coeff;
    int dst_blk_dim;
    int dst_blk_coeff;
};

union reorder_implementation {
    vectorize_group_t vg;
    xb_to_xab_xba_t ab;
    vectorize_last_dim_t vld;
};

class scales_query_t {
public:
    bool has_default_values() const { return scales_.has_default_values(); }
    int get_mask() const { return scales_.mask_; }
    size_t get_count() const { return count_; }
    memory_storage_t &get_scales(const exec_ctx_t &ctx) const {
        return CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | arg_);
    }

    scales_query_t() = default;
    scales_query_t(const primitive_attr_t *attr, const memory_desc_wrapper &mdw,
            int arg)
        : arg_(arg) {
        scales_ = attr->scales_.get(arg);
        count_ = get_attr_oscales_count(scales_.mask_, mdw);
    }

private:
    runtime_scales_t scales_;
    dim_t count_ = 0;
    int arg_ = 0;
};
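
// Sketch of the intended scales_query_t flow (names illustrative): query the
// scales once at primitive creation time, then fetch the runtime buffer at
// execution time:
//
//     scales_query_t wei_scales(attr, weights_mdw, DNNL_ARG_WEIGHTS);
//     // ... inside execute():
//     auto &s = wei_scales.get_scales(ctx); // CTX_IN_STORAGE lookup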

class zero_points_query_t {
public:
    bool has_default_values() const { return zps_.has_default_values(arg_); }
    int get_mask() const {
        int mask;
        zps_.get(arg_, &mask);
        return mask;
    }
    size_t get_count() const { return count_; }
    memory_storage_t &get_zero_points(const exec_ctx_t &ctx) const {
        return CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | arg_);
    }

    zero_points_query_t() = default;
    zero_points_query_t(const primitive_attr_t *attr,
            const memory_desc_wrapper &mdw, int arg)
        : arg_(arg) {
        zps_ = attr->zero_points_;
        int mask;
        zps_.get(arg, &mask);
        count_ = get_attr_oscales_count(mask, mdw);
    }

private:
    zero_points_t zps_;
    // Default-initialize so a default-constructed query is well-defined,
    // matching scales_query_t.
    dim_t count_ = 0;
    int arg_ = 0;
};

struct quantization_t {
public:
    bool with_scale() const { return !scale_.has_default_values(); }
    int scale_mask() const { return scale_.get_mask(); }
    size_t num_scales() const { return scale_.get_count(); }
    memory_storage_t &scales(const exec_ctx_t &ctx) const {
        return scale_.get_scales(ctx);
    }

    bool with_zp() const { return !zp_.has_default_values(); }
    int zp_mask() const { return zp_.get_mask(); }
    size_t num_zps() const { return zp_.get_count(); }
    memory_storage_t &zero_points(const exec_ctx_t &ctx) const {
        return zp_.get_zero_points(ctx);
    }

    void define_macros(
            compute::kernel_ctx_t &kernel_ctx, const std::string &name) const {
        if (with_scale()) {
            kernel_ctx.define_int("WITH_" + name + "_SCALE", 1);
            kernel_ctx.define_int(name + "_SCALE_MASK", scale_mask());
            kernel_ctx.define_int(name + "_NUM_SCALES", num_scales());
        }

        if (with_zp()) {
            kernel_ctx.define_int("WITH_" + name + "_ZPOINT", 1);
            kernel_ctx.define_int(name + "_ZPOINT_MASK", zp_mask());
            kernel_ctx.define_int(name + "_NUM_ZPOINTS", num_zps());
        }
    }

    quantization_t(const primitive_attr_t *attr, const memory_desc_wrapper &mdw,
            int arg)
        : scale_(attr, mdw, arg), zp_(attr, mdw, arg) {}
    quantization_t() = default;

private:
    scales_query_t scale_;
    zero_points_query_t zp_;
};
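
// For example, quantization_t::define_macros(kernel_ctx, "SRC") on a source
// with a common scale and per-dim-1 zero-points would emit definitions along
// the lines of:
//
//     WITH_SRC_SCALE=1  SRC_SCALE_MASK=0  SRC_NUM_SCALES=1
//     WITH_SRC_ZPOINT=1 SRC_ZPOINT_MASK=2 SRC_NUM_ZPOINTS=<per-dim count>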

struct sum_quantization_t {
public:
    bool with_scale() const { return scale_ != 0; }
    int scale_mask() const { return 0; }
    size_t num_scales() const { return (size_t)(with_scale()); }
    float scales() const { return scale_; }

    bool with_zp() const { return zp_ != 0; }
    int zp_mask() const { return 0; }
    size_t num_zps() const { return (size_t)(with_zp()); }
    int zero_points() const { return zp_; }

    void define_macros(
            compute::kernel_ctx_t &kernel_ctx, const std::string &name) const {
        if (with_scale()) kernel_ctx.define_int("WITH_" + name + "_SCALE", 1);
        if (with_zp()) kernel_ctx.define_int("WITH_" + name + "_ZPOINT", 1);
    }

    sum_quantization_t(const primitive_attr_t *attr) {
        const auto &post_ops = attr->post_ops_;
        const int sum_idx = post_ops.find(primitive_kind::sum);
        if (sum_idx != -1) {
            const auto &sum = post_ops.entry_[sum_idx].sum;
            scale_ = sum.scale;
            zp_ = sum.zero_point;
        }
    }
    sum_quantization_t() = default;

private:
    float scale_ = 0;
    int zp_ = 0;
};

struct reorder_conf_t {
    bool has_padding;

    quantization_t src_quant, dst_quant;
    sum_quantization_t sum_quant;

    reorder_kernel_t implementation;
    int ndims;
    size_t nelems;

    compute::dispatch_t dispatch;

    int sub_group_size;
    memory_desc_info_t src_md_info;
    memory_desc_info_t dst_md_info;

    reorder_implementation aux_data;
};

// Concat
struct concat_conf_t {
    dim_t dst_extern_dim_size;
    dim_t src_extern_dim_sizes[64];
    dim_t offset[64];
    dim_t inner_axis;
    dim_t dst_offset0;
    int block;
    int n;
    int simd;
    int data_type_size;
    size_t gws_d[3], lws_d[3];

    data_type_t src_type, dst_type;
    compute::dispatch_t dispatch;
    int ndims;
    memory_desc_info_t src_md_infos[64];
    memory_desc_info_t dst_md_info;
    int concat_axis;
    int sub_group_size;
    int iter_dim_idx, iter_dim_chunk;
};

// Elementwise
struct eltwise_conf_t {
    int ndims;
    int vector_size;
    bool with_zero_padding;
    data_type_t data_type;
    alg_kind_t alg;
    bool is_forward;
    int work_group_size;
    int sub_group_size;
    compute::dispatch_t dispatch;
    memory_desc_info_t data_md_info;
    memory_desc_info_t data_diff_md_info;

    attr_info_t attr_info;
};

// Shuffle
struct shuffle_conf_t {
    data_type_t data_type;
    int axis;
    int transpose_row;
    int transpose_col;
    compute::dispatch_t dispatch;
    memory_desc_info_t src_md_info;
    memory_desc_info_t dst_md_info;
};

inline void set_default_pool_conf(pool_conf_t &conf, const pooling_desc_t &desc,
        const memory_desc_t &src_md, const memory_desc_t &dst_md,
        const primitive_attr_t &attr) {
    const memory_desc_wrapper src_mdw(src_md);
    const memory_desc_wrapper dst_mdw(dst_md);

    const auto &src_dims = src_mdw.dims();
    const auto &dst_dims = dst_mdw.dims();

    int ndims = src_mdw.ndims();
    conf.ndims = ndims;

    conf.mb = src_dims[0];

    conf.c = src_dims[1];
    conf.mb_padded = src_mdw.padded_dims()[0];
    conf.c_padded = src_mdw.padded_dims()[1];
    conf.id = (ndims == 5) ? src_dims[2] : 1;
    conf.ih = (ndims == 3) ? 1 : src_dims[ndims - 2];
    conf.iw = src_dims[ndims - 1];
    conf.od = (ndims == 5) ? dst_dims[2] : 1;
    conf.oh = (ndims == 3) ? 1 : dst_dims[ndims - 2];
    conf.ow = dst_dims[ndims - 1];

    conf.stride_d = (ndims == 5) ? desc.strides[0] : 1;
    conf.stride_h = (ndims == 3) ? 1 : desc.strides[ndims - 4];
    conf.stride_w = desc.strides[ndims - 3];
    conf.kd = (ndims == 5) ? desc.kernel[0] : 1;
    conf.kh = (ndims == 3) ? 1 : desc.kernel[ndims - 4];
    conf.kw = desc.kernel[ndims - 3];

    conf.dd = (ndims == 5) ? desc.dilation[0] : 0;
    conf.dh = (ndims == 3) ? 0 : desc.dilation[ndims - 4];
    conf.dw = desc.dilation[ndims - 3];

    conf.f_pad = (ndims == 5) ? desc.padding[0][0] : 0;
    conf.t_pad = (ndims == 3) ? 0 : desc.padding[0][ndims - 4];
    conf.l_pad = desc.padding[0][ndims - 3];

    conf.alg = desc.alg_kind;

    conf.src_dt = src_mdw.data_type();
    conf.dst_dt = dst_mdw.data_type();

    conf.src_md_info = memory_desc_info_t::create(src_mdw);
    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);

    conf.is_training = desc.prop_kind == prop_kind::forward_training;
    conf.is_backward = desc.prop_kind == prop_kind::backward_data;

    conf.attr_info = attr_info_t::create(&attr);
}

inline void set_default_conf(conv_conf_t &conf, const convolution_desc_t &cd,
        const memory_desc_t &src_md, const memory_desc_t &weights_md,
        const memory_desc_t &dst_md, const memory_desc_t &bias_md,
        const primitive_attr_t &attr) {

    const memory_desc_wrapper src_mdw(&src_md);
    const memory_desc_wrapper weights_mdw(&weights_md);
    const memory_desc_wrapper dst_mdw(&dst_md);
    const memory_desc_wrapper bias_mdw(&bias_md);

    const bool with_groups = weights_mdw.ndims() == src_mdw.ndims() + 1;
    int ndims = src_mdw.ndims();

    conf = utils::zero<decltype(conf)>();
    conf.with_groups = with_groups;
    conf.ndims = ndims;
    conf.prop_kind = cd.prop_kind;
    conf.ngroups = with_groups ? weights_mdw.dims()[0] : 1;
    conf.mb = src_mdw.dims()[0];
    conf.oc_without_padding = dst_mdw.dims()[1] / conf.ngroups;
    conf.ic_without_padding = src_mdw.dims()[1] / conf.ngroups;
    conf.id = (ndims == 5) ? src_mdw.dims()[2] : 1;
    conf.ih = (ndims == 3) ? 1 : src_mdw.dims()[ndims - 2];
    conf.iw = src_mdw.dims()[ndims - 1];
    conf.od = (ndims == 5) ? dst_mdw.dims()[2] : 1;
    conf.oh = (ndims == 3) ? 1 : dst_mdw.dims()[ndims - 2];
    conf.ow = dst_mdw.dims()[ndims - 1];
    conf.kd = (ndims == 5) ? weights_mdw.dims()[with_groups + 2] : 1;
    conf.kh = (ndims == 3) ? 1 : weights_mdw.dims()[with_groups + ndims - 2];
    conf.kw = weights_mdw.dims()[with_groups + ndims - 1];

    conf.is_depthwise = conf.with_groups && conf.oc_without_padding == 1
            && conf.ic_without_padding == 1;
    conf.oc = dst_mdw.dims()[1] / conf.ngroups;
    conf.ic = src_mdw.dims()[1] / conf.ngroups;

    conf.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    conf.back_pad = (ndims == 5) ? cd.padding[1][0] : 0;
    conf.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    conf.b_pad = (ndims == 3) ? 0 : cd.padding[1][ndims - 4];
    conf.l_pad = cd.padding[0][ndims - 3];
    conf.r_pad = cd.padding[1][ndims - 3];
    conf.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    conf.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    conf.stride_w = cd.strides[ndims - 3];
    conf.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    conf.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    conf.dilate_w = cd.dilates[ndims - 3];

    conf.with_bias = bias_mdw.format_kind() != format_kind::undef;

    conf.src_data_type = src_mdw.data_type();
    conf.weights_data_type = weights_mdw.data_type();
    conf.dst_data_type = dst_mdw.data_type();

    conf.acc_data_type = cd.accum_data_type;
    conf.bias_data_type
            = conf.with_bias ? bias_mdw.data_type() : data_type::f32;

    if (!src_mdw.format_any())
        conf.src_md_info = memory_desc_info_t::create(src_mdw);
    if (!weights_mdw.format_any())
        conf.wei_md_info = memory_desc_info_t::create(weights_mdw);
    if (!dst_mdw.format_any())
        conf.dst_md_info = memory_desc_info_t::create(dst_mdw);

    conf.attr_info = attr_info_t::create(&attr);
}

inline void set_offsets(compute::kernel_ctx_t &kernel_ctx,
        const memory_desc_wrapper &md, const char *str) {
    dim_t block_dims[DNNL_MAX_NDIMS];
    dim_t strides_compat[2][DNNL_MAX_NDIMS];

    md.compute_blocks(block_dims);
    md.compute_strides_compat(strides_compat);

    for (int d = 0; d < MAX_NDIMS; ++d) {
        const int block = block_dims[d];

        kernel_ctx.define_int(
                utils::format("%s_B%d", str, d), (d < md.ndims()) ? block : 1);
        kernel_ctx.define_int(utils::format("%s_S%d", str, d),
                (d < md.ndims()) ? strides_compat[0][d] : 0);
        kernel_ctx.define_int(utils::format("%s_SB%d", str, d),
                (d < md.ndims()) ? strides_compat[1][d] : 0);
    }

    kernel_ctx.define_int(utils::format("%s_OFFSET_PAD", str), md.md_->offset0);
}
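
// E.g. set_offsets(kernel_ctx, src_mdw, "SRC") defines SRC_B0..SRC_B5,
// SRC_S0..SRC_S5, SRC_SB0..SRC_SB5 and SRC_OFFSET_PAD, padding dims beyond
// md.ndims() with neutral values (block = 1, strides = 0).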

inline void set_offsets(const memory_desc_wrapper &md, int offs[4][MAX_NDIMS]) {
    dim_t block_dims[DNNL_MAX_NDIMS];
    dim_t strides_compat[2][DNNL_MAX_NDIMS];

    md.compute_blocks(block_dims);
    md.compute_strides_compat(strides_compat);
    const dims_t &dims = md.dims();

    for (int d = 0; d < md.ndims(); ++d) {
        const int block = block_dims[d];

        offs[0][d] = block;
        offs[1][d] = strides_compat[0][d];
        offs[2][d] = strides_compat[1][d];
        offs[3][d] = dims[d];
    }
}

inline void def_offsets(const int offs[4][MAX_NDIMS],
        compute::kernel_ctx_t &kernel_ctx, const char *str, const int ndims) {

    for (int d = 0; d < MAX_NDIMS; d++) {
        kernel_ctx.define_int(
                utils::format("%s_B%d", str, d), (d < ndims) ? offs[0][d] : 1);
        kernel_ctx.define_int(
                utils::format("%s_S%d", str, d), (d < ndims) ? offs[1][d] : 0);
        kernel_ctx.define_int(
                utils::format("%s_SB%d", str, d), (d < ndims) ? offs[2][d] : 0);
        kernel_ctx.define_int(
                utils::format("%s_D%d", str, d), (d < ndims) ? offs[3][d] : 0);
    }
}
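
// The two helpers above form a pair (sketch): extract the layout once with
// set_offsets(), then replay it into a kernel context under a chosen prefix:
//
//     int off[4][MAX_NDIMS];
//     set_offsets(src_mdw, off);
//     def_offsets(off, kernel_ctx, "SRC", src_mdw.ndims());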

inline void def_data_type(
        compute::kernel_ctx_t &kernel_ctx, data_type_t dt, const char *str) {
    switch (dt) {
        case data_type::bf16:
            kernel_ctx.add_option(
                    utils::format("-D%s_DATA_T=ushort -D%s_DT_BF16", str, str));
            break;
        case data_type::f16:
            kernel_ctx.add_option(
                    utils::format("-D%s_DATA_T=half -D%s_DT_F16", str, str));
            break;
        case data_type::f32:
            kernel_ctx.add_option(
                    utils::format("-D%s_DATA_T=float -D%s_DT_F32", str, str));
            break;
        case data_type::f64:
            kernel_ctx.add_option(
                    utils::format("-D%s_DATA_T=double -D%s_DT_F64", str, str));
            break;
        case data_type::s8:
            kernel_ctx.add_option(
                    utils::format("-D%s_DATA_T=char -D%s_DT_S8", str, str));
            break;
        case data_type::u8:
            kernel_ctx.add_option(
                    utils::format("-D%s_DATA_T=uchar -D%s_DT_U8", str, str));
            break;
        case data_type::s32:
            kernel_ctx.add_option(
                    utils::format("-D%s_DATA_T=int -D%s_DT_S32", str, str));
            break;
        default: assert(!"unsupported data type"); break;
    }
}
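
// E.g. def_data_type(kernel_ctx, data_type::f16, "SRC") appends
// "-DSRC_DATA_T=half -DSRC_DT_F16" to the kernel build options.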

inline void def_memory_desc_info(compute::kernel_ctx_t &kernel_ctx,
        const memory_desc_info_t &md_info, const char *prefix) {
    def_data_type(kernel_ctx, md_info.data_type, prefix);

    kernel_ctx.define_int(utils::format("%s_OFFSET0", prefix), md_info.offset0);
    kernel_ctx.define_int(utils::format("%s_NDIMS", prefix), md_info.ndims);

    kernel_ctx.define_int(utils::format("%s_NLEVELS", prefix), md_info.nlevels);

    for (int d = 0; d < MAX_NDIMS; ++d) {
        int dim = (d < md_info.ndims) ? md_info.dims[d] : 1;
        int padded_dim = (d < md_info.ndims) ? md_info.padded_dims[d] : 1;
        kernel_ctx.define_int(utils::format("%s_D%d", prefix, d), dim);
        kernel_ctx.define_int(utils::format("%s_PD%d", prefix, d), padded_dim);

        for (int l = 0; l < md_info.nlevels + 1; ++l) {
            int block = (d < md_info.ndims) ? md_info.blocks[d][l] : 1;
            int stride = (d < md_info.ndims) ? md_info.strides[d][l] : 0;
            kernel_ctx.define_int(
                    utils::format("%s_B%d_%d", prefix, d, l), block);
            kernel_ctx.define_int(
                    utils::format("%s_S%d_%d", prefix, d, l), stride);
        }
    }
}

inline void def_binary_alg_kinds(compute::kernel_ctx_t &kernel_ctx) {
    kernel_ctx.define_int("BINARY_ADD", alg_kind::binary_add);
    kernel_ctx.define_int("BINARY_MUL", alg_kind::binary_mul);
    kernel_ctx.define_int("BINARY_MIN", alg_kind::binary_min);
    kernel_ctx.define_int("BINARY_MAX", alg_kind::binary_max);
    kernel_ctx.define_int("BINARY_DIV", alg_kind::binary_div);
    kernel_ctx.define_int("BINARY_SUB", alg_kind::binary_sub);
    kernel_ctx.define_int("BINARY_GE", alg_kind::binary_ge);
    kernel_ctx.define_int("BINARY_GT", alg_kind::binary_gt);
    kernel_ctx.define_int("BINARY_LE", alg_kind::binary_le);
    kernel_ctx.define_int("BINARY_LT", alg_kind::binary_lt);
    kernel_ctx.define_int("BINARY_EQ", alg_kind::binary_eq);
    kernel_ctx.define_int("BINARY_NE", alg_kind::binary_ne);
}

inline void def_eltwise_alg_kinds(compute::kernel_ctx_t &kernel_ctx) {
    kernel_ctx.define_int("RELU", alg_kind::eltwise_relu);
    kernel_ctx.define_int("LINEAR", alg_kind::eltwise_linear);
    kernel_ctx.define_int("SOFT_RELU", alg_kind::eltwise_soft_relu);
    kernel_ctx.define_int("MISH", alg_kind::eltwise_mish);
    kernel_ctx.define_int("LOGISTIC", alg_kind::eltwise_logistic);
    kernel_ctx.define_int("TANH", alg_kind::eltwise_tanh);
    kernel_ctx.define_int("ELU", alg_kind::eltwise_elu);
    kernel_ctx.define_int("SQUARE", alg_kind::eltwise_square);
    kernel_ctx.define_int("SQRT", alg_kind::eltwise_sqrt);
    kernel_ctx.define_int("ABS", alg_kind::eltwise_abs);
    kernel_ctx.define_int("EXP", alg_kind::eltwise_exp);
    kernel_ctx.define_int("GELU_TANH", alg_kind::eltwise_gelu_tanh);
    kernel_ctx.define_int("SWISH", alg_kind::eltwise_swish);
    kernel_ctx.define_int("LOG", alg_kind::eltwise_log);
    kernel_ctx.define_int("CLIP", alg_kind::eltwise_clip);
    kernel_ctx.define_int("CLIP_V2", alg_kind::eltwise_clip_v2);
    kernel_ctx.define_int("POW", alg_kind::eltwise_pow);
    kernel_ctx.define_int("GELU_ERF", alg_kind::eltwise_gelu_erf);
    kernel_ctx.define_int("ROUND", alg_kind::eltwise_round);
    kernel_ctx.define_int("HARDSWISH", alg_kind::eltwise_hardswish);
    kernel_ctx.define_int("HARDSIGMOID", alg_kind::eltwise_hardsigmoid);

    kernel_ctx.define_int("RELU_DST", alg_kind::eltwise_relu_use_dst_for_bwd);
    kernel_ctx.define_int(
            "LOGISTIC_DST", alg_kind::eltwise_logistic_use_dst_for_bwd);
    kernel_ctx.define_int("TANH_DST", alg_kind::eltwise_tanh_use_dst_for_bwd);
    kernel_ctx.define_int("ELU_DST", alg_kind::eltwise_elu_use_dst_for_bwd);
    kernel_ctx.define_int("SQRT_DST", alg_kind::eltwise_sqrt_use_dst_for_bwd);
    kernel_ctx.define_int("EXP_DST", alg_kind::eltwise_exp_use_dst_for_bwd);
    kernel_ctx.define_int(
            "CLIP_V2_DST", alg_kind::eltwise_clip_v2_use_dst_for_bwd);
}

inline bool post_ops_with_binary_ok(const primitive_attr_t *attr,
        const data_type_t dst_dt, const int max_ndims_supported = 2,
        const int prelu_mask_supported = 3) {
    const auto &p = attr->post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(false); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
    auto is_binary = [&](int idx) { return p.entry_[idx].is_binary(); };
    auto is_prelu = [&](int idx) { return p.entry_[idx].is_prelu(); };

    bool is_po_ok = true;
    for (int po_idx = 0; po_idx < p.len(); ++po_idx) {
        is_po_ok = is_po_ok
                && (is_eltwise(po_idx) || is_sum(po_idx) || is_binary(po_idx)
                        || is_prelu(po_idx));
        if (is_binary(po_idx)) {
            const auto &bin_desc = p.entry_[po_idx].binary.src1_desc;
            if (bin_desc.ndims > max_ndims_supported) {
                // accept descriptor if unsupported dims are equal to 1.
                for (int dim_idx = max_ndims_supported;
                        dim_idx < bin_desc.ndims; ++dim_idx) {
                    if (bin_desc.dims[dim_idx] != 1) is_po_ok = false;
                }
            }
        }
        if (is_prelu(po_idx)) {
            if (p.entry_[po_idx].prelu.mask > prelu_mask_supported)
                is_po_ok = false;
        }
        if (is_sum(po_idx)) {
            if (p.entry_[po_idx].sum.zero_point != 0) return false;
            if (p.entry_[po_idx].sum.dt != dnnl_data_type_undef
                    && types::data_type_size(p.entry_[po_idx].sum.dt)
                            != types::data_type_size(dst_dt))
                return false;
        }
    }

    if (p.len() > MAX_POST_OPS_SUPPORTED) is_po_ok = false;

    return is_po_ok;
}
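
// Typical gating in a pd_t::init() (sketch; the ndims limit is
// implementation-specific):
//
//     if (!post_ops_with_binary_ok(attr(), dst_md()->data_type, 5))
//         return status::unimplemented;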

inline void def_post_ops_cfg(compute::kernel_ctx_t &kernel_ctx,
        const post_ops_t &post_ops, const dnnl_dims_t *dst_dims) {
    const int po_nop_id = 0;
    const int po_binary_id = 1;
    const int po_eltwise_id = 2;
    const int po_sum_id = 3;

    kernel_ctx.define_int("PO_BINARY", po_binary_id);
    kernel_ctx.define_int("PO_ELTWISE", po_eltwise_id);
    kernel_ctx.define_int("PO_SUM", po_sum_id);

    std::string po_kernel_args = "-DPOST_OP_ARGS=\"";
    int nof_supported_post_ops = 0;

    auto add_po_defines = [&](const std::string &bin_arg_name,
                                  const post_ops_t::entry_t &e, int idx) {
        if (e.is_binary()) {
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_binary_id);
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_ALG", e.binary.alg);

            const memory_desc_wrapper src1_mdw(e.binary.src1_desc);
            const auto mdi = memory_desc_info_t::create(src1_mdw);
            def_memory_desc_info(kernel_ctx, mdi, bin_arg_name.c_str());
            if (mdi.data_type == data_type::bf16) {
                kernel_ctx.define_int(
                        "PO_" + std::to_string(idx) + "_BIN_ARG_DT_IS_BF16", 1);
            } else {
                kernel_ctx.define_int(
                        "PO_" + std::to_string(idx) + "_BIN_ARG_DT_IS_BF16", 0);
            }
        } else if (e.is_prelu()) {
            // A binary post-op kind combined with the eltwise relu algorithm
            // encodes a prelu post-op.
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_binary_id);
            kernel_ctx.define_int("PO_" + std::to_string(idx) + "_ALG",
                    alg_kind_t::dnnl_eltwise_relu);

            assert(dst_dims != nullptr);

            memory_desc_t weight_mem_desc;
            dims_t weight_dims {};
            format_tag_t weights_tag;
            int weight_ndims = 0;
            if (e.prelu.mask == 0) {
                weight_ndims = 1;
                weight_dims[0] = 1;
                weights_tag = format_tag_t::dnnl_a;
            } else {
                // prelu weights are assumed to have up to 5 dims
                for (int d = 0; d < 5; ++d) {
                    if (((e.prelu.mask >> d) & 0x1) == 1) {
                        weight_ndims = d + 1;
                        weight_dims[d] = (*dst_dims)[d];
                    } else {
                        weight_dims[d] = 1;
                    }
                }
                switch (weight_ndims) {
                    case 1: weights_tag = format_tag_t::dnnl_a; break;
                    case 2: weights_tag = format_tag_t::dnnl_ab; break;
                    case 3: weights_tag = format_tag_t::dnnl_acb; break;
                    case 4: weights_tag = format_tag_t::dnnl_acdb; break;
                    case 5: weights_tag = format_tag_t::dnnl_acdeb; break;
                    default:
                        weights_tag = format_tag_t::dnnl_format_tag_undef;
                        break;
                }
            }
            memory_desc_init_by_tag(weight_mem_desc, weight_ndims, weight_dims,
                    data_type_t::dnnl_f32, weights_tag);
            const memory_desc_wrapper weight_mdw(weight_mem_desc);
            const auto mdi = memory_desc_info_t::create(weight_mdw);
            def_memory_desc_info(kernel_ctx, mdi, bin_arg_name.c_str());

            // prelu weights are assumed to be f32
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_BIN_ARG_DT_IS_BF16", 0);
        } else {
            memory_desc_t empty_mem_desc;
            dnnl_dims_t empty_dims = {1, 1, 1, 1};
            memory_desc_init_by_tag(empty_mem_desc, 4, empty_dims,
                    data_type_t::dnnl_s8, format_tag_t::dnnl_nchw);
            const memory_desc_wrapper src1_mdw(empty_mem_desc);
            const auto mdi = memory_desc_info_t::create(src1_mdw);
            def_memory_desc_info(kernel_ctx, mdi, bin_arg_name.c_str());
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_BIN_ARG_DT_IS_BF16", 0);
        }
        if (e.is_eltwise(false)) {
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_eltwise_id);
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_ALG", e.eltwise.alg);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_ALPHA").c_str(),
                    e.eltwise.alpha);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_BETA").c_str(),
                    e.eltwise.beta);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_SCALE").c_str(),
                    e.eltwise.scale);
        } else {
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_ALPHA").c_str(),
                    1.0f);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_BETA").c_str(),
                    0.0f);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_ELTWISE_SCALE").c_str(),
                    1.0f);
        }
        if (e.is_sum(false)) {
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_sum_id);
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_ALG", alg_kind::undef);
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(),
                    e.sum.scale);
        } else {
            kernel_ctx.define_float(
                    ("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(), 1.0f);
        }
        if (!(e.is_binary() || e.is_eltwise(false) || e.is_sum(false)
                    || e.is_prelu())) {
            // Empty post-op.
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_KIND", po_nop_id);
            // *_ALG needs to be set, but it is unused when the kind is NOP.
            kernel_ctx.define_int(
                    "PO_" + std::to_string(idx) + "_ALG", alg_kind::undef);
            --nof_supported_post_ops;
        }
        po_kernel_args += ", const __global PO_" + std::to_string(idx)
                + "_BIN_ARG_DATA_T *po_" + std::to_string(idx) + "_binary_arg";
    };

    for (int idx = 0; idx < post_ops.len(); ++idx, ++nof_supported_post_ops) {
        const std::string bin_arg_name
                = "PO_" + std::to_string(idx) + "_BIN_ARG";
        add_po_defines(bin_arg_name, post_ops.entry_[idx], idx);
    }

    kernel_ctx.define_int("POST_OP_CHAIN_LENGTH", nof_supported_post_ops);
    if (post_ops.len() > 0) {
        // Due to the limitations of the C macros that the post-op machinery
        // is built on, always load the bf16 conversion functions.
        kernel_ctx.define_int("POST_OP_USING_BF16", 1);
    }
    po_kernel_args += "\"";
    kernel_ctx.add_option(po_kernel_args);
}

inline int append_post_ops_to_arg_list_base(const exec_args_t &args,
        compute::kernel_arg_list_t &arg_list, int post_op_idx,
        const post_ops_t &post_ops) {
    auto set_arg_entry = [&](const post_ops_t::entry_t &e, int po_idx) {
        if (e.is_binary()) {
            auto arg = args.at(
                    DNNL_ARG_ATTR_MULTIPLE_POST_OP(po_idx) | DNNL_ARG_SRC_1);
            assert(arg.is_const);

            auto &binary_arg = arg.mem
                    ? *(arg.mem->memory_storage())
                    : dnnl::impl::memory_storage_t::empty_storage();
            arg_list.set(post_op_idx++, binary_arg);
        } else if (e.is_prelu()) {
            auto arg = args.at(
                    DNNL_ARG_ATTR_MULTIPLE_POST_OP(po_idx) | DNNL_ARG_WEIGHTS);
            assert(arg.is_const);
            auto &prelu_wei_arg = arg.mem
                    ? *(arg.mem->memory_storage())
                    : dnnl::impl::memory_storage_t::empty_storage();
            arg_list.set(post_op_idx++, prelu_wei_arg);
        } else {
            arg_list.set(post_op_idx++, memory_storage_t::empty_storage());
        }
    };

    for (int idx = 0; idx < post_ops.len(); ++idx) {
        set_arg_entry(post_ops.entry_[idx], idx);
    }
    return post_op_idx;
}
inline int append_post_ops_to_arg_list_gemm(const exec_args_t &args,
        compute::kernel_arg_list_t &arg_list, int post_op_idx,
        const post_ops_t &post_ops) {
    return append_post_ops_to_arg_list_base(
            args, arg_list, post_op_idx, post_ops);
}
inline int append_post_ops_to_arg_list(const exec_ctx_t &ctx,
        compute::kernel_arg_list_t &arg_list, int post_op_idx,
        const post_ops_t &post_ops) {
    return append_post_ops_to_arg_list_base(
            ctx.args(), arg_list, post_op_idx, post_ops);
}
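
// Execution-time pairing with def_post_ops_cfg() (sketch): the primitive sets
// its own kernel arguments first, then each post-op contributes exactly one
// buffer slot, in post-op order:
//
//     int arg_idx = 4; // hypothetical number of fixed kernel args
//     arg_idx = append_post_ops_to_arg_list(
//             ctx, arg_list, arg_idx, pd()->attr()->post_ops_);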

inline bool post_ops_preserves_zeroes(
        const exec_ctx_t &ctx, const post_ops_t &post_ops) {
    bool preserve_zeroes = true;
    for (int idx = 0; idx < post_ops.len(); ++idx) {
        const post_ops_t::entry_t &po_entry = post_ops.entry_[idx];
        if (po_entry.is_binary()) {
            // Among the binary algorithms, only binary mul preserves zeroes.
            preserve_zeroes &= po_entry.binary.alg
                    == dnnl::impl::alg_kind_t::dnnl_binary_mul;
        }
        if (po_entry.is_eltwise(false)) {
            preserve_zeroes &= gpu_eltwise_fwd_pd_t::eltwise_preserves_zero(
                    po_entry.eltwise.alg, po_entry.eltwise.alpha,
                    po_entry.eltwise.beta);
        }
    }
    return preserve_zeroes;
}

inline void def_attr_info(compute::kernel_ctx_t &kernel_ctx,
        const attr_info_t &attr_info, const post_ops_t &post_ops,
        const dnnl_dims_t *dst_dims = nullptr) {
    assert(attr_info.initialized);

    kernel_ctx.define_int("WITH_POST_OP", post_ops.len() > 0);

    kernel_ctx.define_int("WITH_ELTWISE", attr_info.with_eltwise);
    kernel_ctx.define_int("ELTWISE_IDX", attr_info.eltwise_idx);
    kernel_ctx.define_int("ELTWISE_ALG", attr_info.eltwise_alg);

    kernel_ctx.define_int("WITH_SUM", attr_info.with_sum);
    kernel_ctx.define_int("SUM_IDX", attr_info.sum_idx);
    kernel_ctx.define_int("SUM_SCALE", attr_info.sum_scale);
    kernel_ctx.define_int("SUM_SCALE1", attr_info.sum_scale == 1.0f);

    kernel_ctx.define_int("WITH_SRC0_SCALE", attr_info.with_src0_scale);
    kernel_ctx.define_int("WITH_SRC1_SCALE", attr_info.with_src1_scale);

    kernel_ctx.define_int("WITH_SCALES", attr_info.with_oscales);
    kernel_ctx.define_int(
            "WITH_RUNTIME_SCALES", attr_info.with_runtime_oscales);
    kernel_ctx.define_int("SCALES_PER_OC", attr_info.with_per_oc_oscales);
    kernel_ctx.define_int("SCALES_COMMON", attr_info.with_common_oscales);

    kernel_ctx.define_int("WITH_SRC_SCALES", attr_info.with_src_scales);
    kernel_ctx.define_int("WITH_WEI_SCALES", attr_info.with_wei_scales);
    kernel_ctx.define_int("WITH_DST_SCALES", attr_info.with_dst_scales);
    kernel_ctx.define_int("WEI_SCALES_MASK", attr_info.wei_scales_mask);

    kernel_ctx.define_int("WITH_SRC_ZPOINTS", attr_info.with_src_zpoints);
    kernel_ctx.define_int("WITH_WEI_ZPOINTS", attr_info.with_wei_zpoints);
    kernel_ctx.define_int("WITH_DST_ZPOINTS", attr_info.with_dst_zpoints);
    kernel_ctx.define_int(
            "WITH_SRC_ZPOINTS_PER_IC", attr_info.with_per_ic_src_zpoints);
    kernel_ctx.define_int(
            "WITH_DST_ZPOINTS_PER_OC", attr_info.with_per_oc_dst_zpoints);

    def_binary_alg_kinds(kernel_ctx);
    def_eltwise_alg_kinds(kernel_ctx);

    def_post_ops_cfg(kernel_ctx, post_ops, dst_dims);
}

inline void def_dispatch(compute::kernel_ctx_t &kernel_ctx,
        const compute::dispatch_t &dispatch) {
    dispatch.def_kernel_macros(kernel_ctx);
}

inline void maybe_fix_non_uniform_work_sizes(
        bool has_non_uniform_wg, conv_conf_t &conf) {
    for (int i = 0; i < 3; i++) {
        conf.gws_orig_d[i] = conf.gws_d[i];
        if (!has_non_uniform_wg)
            conf.gws_d[i] = utils::rnd_up(conf.gws_d[i], conf.lws_d[i]);
    }
}
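
// Example: with lws_d = {16, 1, 1} and gws_d = {50, 1, 1} on a device without
// non-uniform work-group support, gws_orig_d keeps {50, 1, 1} while gws_d is
// rounded up to {64, 1, 1}, so the kernel can bounds-check work-item ids
// against the original sizes.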

inline void bwd_w_compute_block_sizes(conv_conf_t &conf, engine_t *engine) {
    const bool is_1stconv = conf.ic_without_padding == 3;

    if (conf.is_depthwise) {
        conf.odb = 1;
        conf.ohb = 1;
        conf.owb = utils::rnd_up(conf.ow, conf.ow_block);
        conf.ocb = 1;
        conf.icb = 1;
        conf.osp_chunk = utils::div_up(conf.od, conf.odb)
                * utils::div_up(conf.oh, conf.ohb)
                * utils::div_up(conf.ow, conf.owb);

        conf.mb_chunk = utils::div_up(conf.mb, conf.mb_block);
        conf.nchunk = conf.osp_chunk * conf.mb_chunk;
        return;
    }
    auto *dev_info = utils::downcast<compute::compute_engine_t *>(engine)
                             ->device_info();
    int hw_threads = dev_info->hw_threads();
    size_t llc_bytes = dev_info->llc_cache_size();

    auto next_candidate = [](int size, int block) {
        if (size == block) return block;
        // If size is big enough, then do not care about the remainder.
        if (block * 16 < size) return block + 1;
        // Otherwise search for the next divisor.
        block++;
        while (size % block != 0)
            block++;
        return block;
    };

    int mb_nb = 1;
    conf.odb = 1;
    conf.ohb = 1;
    conf.owb = 1;

    int mb_nblk = utils::div_up(conf.mb, conf.mb_block);
    int ic_nblk = utils::div_up(conf.ic, conf.ic_block);
    int oc_nblk = utils::div_up(conf.oc, conf.oc_block);

    int ic_nb_max = is_1stconv ? 1 : nstl::min(ic_nblk, 16);
    int oc_nb_max = nstl::min(oc_nblk, 16);
    int ic_nb = is_1stconv ? 1 : utils::max_div(ic_nblk, ic_nb_max);
    int oc_nb = utils::max_div(oc_nblk, oc_nb_max);

    int mb_nb_max = 1;
    if (!is_1stconv && (conf.mb_block == 1) && (conf.ic % 1024 != 0)
            && (conf.oc % 1024 != 0)) {
        mb_nb_max = 4;
    }

    auto get_nthr = [&]() {
        int nthr = utils::div_up(mb_nblk, mb_nb)
                * utils::div_up(conf.od, conf.odb)
                * utils::div_up(conf.oh, conf.ohb)
                * utils::div_up(conf.ow, conf.owb) * conf.kh * conf.kw * conf.kd
                * oc_nblk * (is_1stconv ? 1 : ic_nblk) * conf.ngroups;
        return nthr;
    };

    auto get_src_dst_size = [&]() {
        int iwb = conf.ndims < 3 ? 1 : conf.owb + 2 * (conf.kw - 1);
        int ihb = conf.ndims < 4 ? 1 : conf.ohb + 2 * (conf.kh - 1);
        int idb = conf.ndims < 5 ? 1 : conf.odb + 2 * (conf.kd - 1);

        size_t ispb = iwb * ihb * idb;
        size_t ospb = conf.owb * conf.ohb * conf.odb;
        size_t src_size = sizeof(float) * conf.mb_block
                * (is_1stconv ? conf.ic : ic_nb * conf.ic_block) * ispb;
        size_t dst_size = sizeof(float) * conf.mb_block
                * (oc_nb * conf.oc_block) * ospb;

        int nthr_per_spb
                = conf.kh * conf.kw * conf.kd * ic_nb * oc_nb * conf.ngroups;
        size_t sz = (size_t)(src_size + dst_size);
        if (nthr_per_spb < hw_threads) sz = sz * hw_threads / nthr_per_spb;
        return sz;
    };

    auto try_next = [&](int &v, int next) {
        if (next <= v) return false;
        int v_old = v;
        v = next;
        // Heuristics:
        // - the src and dst sizes accessed in the inner loops should fit
        //   the LLC
        // - require at least (3 * hw_threads) threads to spawn to have
        //   enough parallelism
        if (get_src_dst_size() > llc_bytes || get_nthr() < 3 * hw_threads) {
            v = v_old;
            return false;
        }
        return true;
    };

    if (utils::one_of(conf.ver, ver_nhwc, ver_8ow16c, ver_1stconv))
        conf.owb = conf.ow_block;

    // Increase spatial tile size as much as possible.
    for (int i = 0; i < 128; i++) {
        int owb_next;
        if (utils::one_of(conf.ver, ver_nhwc, ver_8ow16c, ver_1stconv)) {
            int ow_padded = utils::rnd_up(conf.ow, conf.ow_block);
            owb_next = conf.ow_block
                    * next_candidate(ow_padded / conf.ow_block,
                            conf.owb / conf.ow_block);
        } else {
            owb_next = next_candidate(conf.ow, conf.owb);
        }
        try_next(conf.owb, owb_next);

        int ohb_next = next_candidate(conf.oh, conf.ohb);
        try_next(conf.ohb, ohb_next);

        int odb_next = next_candidate(conf.od, conf.odb);
        try_next(conf.odb, odb_next);

        int mb_nb_next = next_candidate(mb_nb_max, mb_nb);
        try_next(mb_nb, mb_nb_next);
    }

    conf.icb = is_1stconv ? conf.ic : ic_nb * conf.ic_block;
    conf.ocb = oc_nb * conf.oc_block;

    conf.osp_chunk = utils::div_up(conf.od, conf.odb)
            * utils::div_up(conf.oh, conf.ohb)
            * utils::div_up(conf.ow, conf.owb);

    conf.mb_chunk = utils::div_up(mb_nblk, mb_nb);

    conf.nchunk = conf.mb_chunk * conf.osp_chunk;
}

} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif