1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_GEMM_CONVOLUTION_UTILS_HPP
18#define CPU_GEMM_CONVOLUTION_UTILS_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23
24#include "cpu/cpu_convolution_pd.hpp"
25#include "cpu/cpu_engine.hpp"
26#include "cpu/zero_point_utils.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31
// Loop nesting order selected for a GEMM-based convolution; stored in
// conv_gemm_conf_t::loop_order.
// NOTE(review): each three-letter suffix presumably spells the nesting order of
// three outer loops; the letter meanings are defined by the implementation,
// not in this header — confirm there before relying on a particular order.
// Values are spelled out explicitly; they match the implicit 0, 1, 2.
enum conv_gemm_loop_order_t {
    gemm_loop_rlb = 0,
    gemm_loop_lrb = 1,
    gemm_loop_lbr = 2,
};
33struct conv_gemm_conf_t {
34 prop_kind_t prop_kind;
35
36 dim_t mb;
37 dim_t ngroups, ic, oc;
38 dim_t iw, ih, id, ow, oh, od;
39 dim_t l_pad, t_pad, f_pad, e_pad, b_pad, r_pad;
40 dim_t kh, kw, kd;
41 dim_t stride_h, stride_w, stride_d;
42 dim_t dilate_h, dilate_w, dilate_d;
43 bool with_bias;
44 bool with_eltwise;
45 bool with_binary;
46 bool with_sum;
47 post_ops_t post_ops;
48 bool is_nspc;
49
50 dim_t is, os, ks;
51 dim_t ic_block, oc_block;
52
53 int nthr;
54 ptrdiff_t im2col_sz;
55 bool need_wei_reduction;
56 bool signed_input;
57 dim_t oh_block;
58 dim_t ow_block;
59 dim_t os_block, os_nb_block;
60 bool outer_threading;
61 conv_gemm_loop_order_t loop_order;
62 int nthr_oc;
63
64 zero_point_config_t zp;
65
66 data_type_t bias_data_type;
67 data_type_t dst_data_type;
68 data_type_t sum_data_type;
69 size_t dst_os_stride;
70 size_t scale_idx_mult;
71 bool with_dst_scale;
72};
73
74struct single_gemm_conv_chunk_desc_t {
75 single_gemm_conv_chunk_desc_t() = default;
76 single_gemm_conv_chunk_desc_t(dim_t d_off, dim_t d_size, dim_t h_off,
77 dim_t h_size, dim_t w_off, dim_t w_size);
78
79 dim_t d_off_ = 0;
80 dim_t d_size_ = 0;
81 dim_t h_off_ = 0;
82 dim_t h_size_ = 0;
83 dim_t w_off_ = 0;
84 dim_t w_size_ = 0;
85};
86
87namespace jit_gemm_convolution_utils {
88template <typename data_type_t>
89void im2col_3d(const conv_gemm_conf_t &jcp, const data_type_t *im,
90 data_type_t *col, dim_t od, int spatial_step, int spatial_block);
91
92template <typename T>
93void transpose_dt(const conv_gemm_conf_t &jcp, const T *__restrict im,
94 T *__restrict imtr);
95
96template <typename im_dt, typename col_dt>
97void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict im,
98 col_dt *__restrict col, dim_t od);
99
100template <typename data_type_t>
101void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im,
102 data_type_t *__restrict col, dim_t ss, dim_t sb, dim_t cs, dim_t cb);
103
104template <typename im_dt, typename col_dt>
105void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict im,
106 void *__restrict imtr, col_dt *__restrict col, dim_t hs, dim_t hb,
107 dim_t ws, dim_t wb);
108
109template <typename T>
110void col2im_dt(
111 const conv_gemm_conf_t &jcp, const T *__restrict col, T *__restrict im);
112void col2im_3d(const conv_gemm_conf_t &jcp, const float *col, float *im,
113 dim_t od, int spatial_step, int spatial_block);
114void col2im(const conv_gemm_conf_t &jcp, const float *col, float *im,
115 int spatial_step, int spatial_block);
116
117status_t init_conf(conv_gemm_conf_t &jcp,
118 memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
119 memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
120 memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads);
121
122void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
123 int &nthr_g, int &ithr_mb, int &nthr_mb);
124void bwd_weights_reduction_par_ncsp(int ithr, int nthr,
125 const conv_gemm_conf_t &jcp, const float *weights_reduce_ws,
126 float *weights);
127void bwd_weights_reduction_par_nspc(int ithr, int nthr, size_t g_start,
128 size_t g_end, const conv_gemm_conf_t &jcp,
129 const float *weights_reduce_base, float *diff_weights);
130
131bool padding_exists(const conv_gemm_conf_t &jcp) noexcept;
132
133} // namespace jit_gemm_convolution_utils
134
135} // namespace cpu
136} // namespace impl
137} // namespace dnnl
138
139#endif
140