1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_GEMM_GEMM_INFO_HPP
18#define CPU_X64_GEMM_GEMM_INFO_HPP
19
20#include <cstdint>
21#include <memory>
22
23#include "common/c_types_map.hpp"
24
25#include "cpu/x64/gemm/gemm_pack_storage.hpp"
26#include "cpu/x64/gemm/gemm_threading.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
// Which GEMM operand, if any, is requested to be packed; this is the type of
// gemm_info_t::packing.
enum class pack_type {
    none = 0, // no operand packed
    pack_a = 1, // operand A
    pack_b = 2, // operand B
};
34
// Form of the C offset; this is the type of gemm_info_t::offsetc.
// NOTE(review): judging by the names, `fixed`, `column` and `row` presumably
// mean a single scalar, one value per column and one value per row of C,
// respectively -- confirm against the constructor's handling of offsetC/oc.
enum class offset_type {
    none = 0,
    fixed = 1,
    column = 2,
    row = 3,
};
41
// Index constants for the kernel lookup tables in gemm_info_t, e.g.
// copy_a_kern[trans][sum] and kern[beta0][alpha1][col_off][row_off].
// They are deliberately unscoped so they convert to integers when used as
// array subscripts.
// TODO: Is it okay to place these here?

// Second subscript of copy_a_kern / copy_b_kern.
enum {
    no_sum = 0,
    do_sum = 1,
};

// Transposition selector.
enum {
    no_trans = 0,
    do_trans = 1,
    packed = 2,
};

// beta == 0 selector.
enum {
    no_beta0 = 0,
    do_beta0 = 1,
};

// alpha == 1 selector.
enum {
    no_alpha1 = 0,
    do_alpha1 = 1,
};
47
48template <typename a_t, typename b_t, typename c_t>
49struct gemm_info_t {
50
51 // Interface arguments.
52 int transa, transb;
53 offset_type offsetc;
54 dim_t m, n, k;
55 dim_t lda, ldb, ldc;
56 const a_t *a;
57 const b_t *b;
58 c_t *c;
59 float alpha, beta;
60
61 int32_t ao;
62 int32_t bo;
63 const c_t *co;
64
65 pack_type packing;
66 gemm_pack_storage_t *pack_dst;
67 bool measure_only;
68 std::shared_ptr<const gemm_pack_storage_t> a_packed, b_packed;
69
70 // Kernel parameters.
71 dim_t um, un, uk, bm, bn, bk;
72 dim_t bn_small_k, bk_traditional, blocking_small_k;
73
74 // Gemv parameters
75 int swap;
76
77 using copy_a_fptr_t = void (*)(const dim_t *m, const dim_t *n,
78 const a_t *src, const dim_t *ldsrc, const float *alpha, a_t *dst,
79 const dim_t *dummy1, const dim_t *dummy2, c_t *row_col_sum);
80
81 using copy_b_fptr_t = void (*)(const dim_t *m, const dim_t *n,
82 const b_t *src, const dim_t *ldsrc, const float *alpha, b_t *dst,
83 const dim_t *dummy1, const dim_t *dummy2, c_t *row_col_sum);
84
85 using gemm_fptr_t = void (*)(const dim_t *, const dim_t *, const dim_t *,
86 const float *, const a_t *, const b_t *, c_t *, const dim_t,
87 const c_t *, const c_t *);
88
89 using gemv_fptr_t = void (*)(const dim_t *, const dim_t *, const float *,
90 const a_t *, const dim_t *, const b_t *, const dim_t *, c_t *,
91 const dim_t *);
92
93 using gemv_s8s8s32_fptr_t
94 = void (*)(const dim_t, const dim_t, const float, const int8_t *,
95 const dim_t, const int8_t *, const float, int32_t *);
96
97 using gemv_s8u8s32_fptr_t
98 = void (*)(const dim_t, const dim_t, const float, const int8_t *,
99 const dim_t, const uint8_t *, const float, int32_t *);
100
101 using gemv_u8s8s32_fptr_t
102 = void (*)(const dim_t, const dim_t, const float, const uint8_t *,
103 const dim_t, const int8_t *, const float, int32_t *);
104
105 // gemm kernels
106 copy_a_fptr_t copyA = nullptr;
107 copy_b_fptr_t copyB = nullptr;
108 gemm_fptr_t kernel[2][2][2] = {{{nullptr}}};
109
110 // gemv kernels
111 gemv_fptr_t gemv_kernel[2] = {nullptr};
112 gemv_s8s8s32_fptr_t gemv_s8s8s32_kernel = nullptr;
113 gemv_s8u8s32_fptr_t gemv_s8u8s32_kernel = nullptr;
114 gemv_u8s8s32_fptr_t gemv_u8s8s32_kernel = nullptr;
115
116 // copyA[trans][sum]
117 static copy_a_fptr_t copy_a_kern[2][2];
118
119 // copyB[trans][sum]
120 static copy_b_fptr_t copy_b_kern[2][2];
121
122 // kern[beta0][alpha1][col_off][row_off]
123 static gemm_fptr_t kern[2][2][2][2];
124
125 // gemv_kern[trans]
126 static gemv_fptr_t gemv_kern[2];
127
128 static gemv_s8s8s32_fptr_t gemv_s8s8s32_kern;
129 static gemv_s8u8s32_fptr_t gemv_s8u8s32_kern;
130 static gemv_u8s8s32_fptr_t gemv_u8s8s32_kern;
131
132 template <bool is_trans>
133 static void copy_a_sum_ref(const dim_t *p_k, const dim_t *p_m,
134 const a_t *src, const dim_t *p_ld, const float *p_alpha, a_t *dst,
135 const dim_t *dummy1, const dim_t *dummy2, c_t *a_row_sum) {
136
137 copy_a_kern[is_trans][no_sum](
138 p_k, p_m, src, p_ld, p_alpha, dst, dummy1, dummy2, a_row_sum);
139
140 dim_t k = *p_k;
141 dim_t m = *p_m;
142 dim_t ld = *p_ld;
143
144 // Calculate op(A) row sum.
145 if (!is_trans) {
146 PRAGMA_OMP_SIMD()
147 for (dim_t i = 0; i < m; i++)
148 a_row_sum[i] = 0;
149
150 for (dim_t j = 0; j < k; j++) {
151 PRAGMA_OMP_SIMD()
152 for (dim_t i = 0; i < m; i++) {
153 a_row_sum[i] += src[i + j * ld];
154 }
155 }
156 } else {
157 for (dim_t i = 0; i < m; i++) {
158 c_t acc = 0;
159
160 PRAGMA_OMP_SIMD(reduction(+ : acc))
161 for (dim_t j = 0; j < k; j++) {
162 acc += src[j + i * ld];
163 }
164
165 a_row_sum[i] = acc;
166 }
167 }
168 }
169
170 template <bool is_trans>
171 static void copy_b_sum_ref(const dim_t *p_k, const dim_t *p_n,
172 const b_t *src, const dim_t *p_ld, const float *alpha, b_t *dst,
173 const dim_t *dummy1, const dim_t *dummy2, c_t *b_col_sum) {
174
175 copy_b_kern[is_trans][no_sum](
176 p_k, p_n, src, p_ld, alpha, dst, dummy1, dummy2, b_col_sum);
177
178 dim_t k = *p_k;
179 dim_t n = *p_n;
180 dim_t ld = *p_ld;
181
182 // Calculate op(B) column sum.
183 if (!is_trans) {
184 for (dim_t j = 0; j < n; j++) {
185 c_t acc = 0;
186
187 PRAGMA_OMP_SIMD(reduction(+ : acc))
188 for (dim_t i = 0; i < k; i++)
189 acc += src[i + j * ld];
190
191 b_col_sum[j] = acc;
192 }
193 } else {
194 PRAGMA_OMP_SIMD()
195 for (dim_t j = 0; j < n; j++)
196 b_col_sum[j] = 0;
197
198 for (dim_t i = 0; i < k; i++) {
199 PRAGMA_OMP_SIMD()
200 for (dim_t j = 0; j < n; j++)
201 b_col_sum[j] += src[j + i * ld];
202 }
203 }
204 }
205
206 bool force_nocopy;
207
208 gemm_info_t(const char *transA, const char *transB, const char *offsetC,
209 const dim_t *m, const dim_t *n, const dim_t *k, const float *alpha,
210 const a_t *a, const dim_t *lda, const a_t *oa, const b_t *b,
211 const dim_t *ldb, const b_t *ob, const float *beta, c_t *c,
212 const dim_t *ldc, const c_t *oc, bool force_nocopy,
213 pack_type packing, gemm_pack_storage_t *pack_dst,
214 bool measure_only);
215
216 bool hasKernels(void);
217
218 void update_blocking(const gemm_threading_t &thread_info);
219
220private:
221 void jit_init(void);
222};
223
224} // namespace x64
225} // namespace cpu
226} // namespace impl
227} // namespace dnnl
228
229#endif // CPU_X64_GEMM_GEMM_INFO_HPP
230