1 | /******************************************************************************* |
2 | * Copyright 2019-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_GEMM_GEMM_INFO_HPP |
18 | #define CPU_X64_GEMM_GEMM_INFO_HPP |
19 | |
20 | #include <cstdint> |
21 | #include <memory> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | |
25 | #include "cpu/x64/gemm/gemm_pack_storage.hpp" |
26 | #include "cpu/x64/gemm/gemm_threading.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
// Packing mode requested for a GEMM call; stored in gemm_info_t::packing.
enum class pack_type {
    none = 0, // no packing requested
    pack_a = 1, // pack matrix A
    pack_b = 2, // pack matrix B
};
34 | |
// Kind of C offset attached to a GEMM call; stored in gemm_info_t::offsetc.
// NOTE(review): presumably decoded from the constructor's `offsetC`
// character argument -- confirm against the constructor definition.
enum class offset_type {
    none = 0,
    fixed = 1,
    column = 2,
    row = 3,
};
41 | |
// Symbolic indices into the kernel lookup tables declared in gemm_info_t.
// They are deliberately unscoped enums so they can be used directly as array
// subscripts. TODO Is it okay to place this here?

// "sum" dimension of the copy kernel tables (copy_*_kern[trans][sum]).
enum {
    no_sum = 0,
    do_sum = 1,
};

// "trans" dimension of the copy/gemv kernel tables.
// NOTE(review): `packed` lies outside the 2-entry tables in this header;
// presumably it is only used as a transa/transb value -- confirm.
enum {
    no_trans = 0,
    do_trans = 1,
    packed = 2,
};

// "beta0" dimension of the compute kernel table.
enum {
    no_beta0 = 0,
    do_beta0 = 1,
};

// "alpha1" dimension of the compute kernel table.
enum {
    no_alpha1 = 0,
    do_alpha1 = 1,
};
47 | |
48 | template <typename a_t, typename b_t, typename c_t> |
49 | struct gemm_info_t { |
50 | |
51 | // Interface arguments. |
52 | int transa, transb; |
53 | offset_type offsetc; |
54 | dim_t m, n, k; |
55 | dim_t lda, ldb, ldc; |
56 | const a_t *a; |
57 | const b_t *b; |
58 | c_t *c; |
59 | float alpha, beta; |
60 | |
61 | int32_t ao; |
62 | int32_t bo; |
63 | const c_t *co; |
64 | |
65 | pack_type packing; |
66 | gemm_pack_storage_t *pack_dst; |
67 | bool measure_only; |
68 | std::shared_ptr<const gemm_pack_storage_t> a_packed, b_packed; |
69 | |
70 | // Kernel parameters. |
71 | dim_t um, un, uk, bm, bn, bk; |
72 | dim_t bn_small_k, bk_traditional, blocking_small_k; |
73 | |
74 | // Gemv parameters |
75 | int swap; |
76 | |
77 | using copy_a_fptr_t = void (*)(const dim_t *m, const dim_t *n, |
78 | const a_t *src, const dim_t *ldsrc, const float *alpha, a_t *dst, |
79 | const dim_t *dummy1, const dim_t *dummy2, c_t *row_col_sum); |
80 | |
81 | using copy_b_fptr_t = void (*)(const dim_t *m, const dim_t *n, |
82 | const b_t *src, const dim_t *ldsrc, const float *alpha, b_t *dst, |
83 | const dim_t *dummy1, const dim_t *dummy2, c_t *row_col_sum); |
84 | |
85 | using gemm_fptr_t = void (*)(const dim_t *, const dim_t *, const dim_t *, |
86 | const float *, const a_t *, const b_t *, c_t *, const dim_t, |
87 | const c_t *, const c_t *); |
88 | |
89 | using gemv_fptr_t = void (*)(const dim_t *, const dim_t *, const float *, |
90 | const a_t *, const dim_t *, const b_t *, const dim_t *, c_t *, |
91 | const dim_t *); |
92 | |
93 | using gemv_s8s8s32_fptr_t |
94 | = void (*)(const dim_t, const dim_t, const float, const int8_t *, |
95 | const dim_t, const int8_t *, const float, int32_t *); |
96 | |
97 | using gemv_s8u8s32_fptr_t |
98 | = void (*)(const dim_t, const dim_t, const float, const int8_t *, |
99 | const dim_t, const uint8_t *, const float, int32_t *); |
100 | |
101 | using gemv_u8s8s32_fptr_t |
102 | = void (*)(const dim_t, const dim_t, const float, const uint8_t *, |
103 | const dim_t, const int8_t *, const float, int32_t *); |
104 | |
105 | // gemm kernels |
106 | copy_a_fptr_t copyA = nullptr; |
107 | copy_b_fptr_t copyB = nullptr; |
108 | gemm_fptr_t kernel[2][2][2] = {{{nullptr}}}; |
109 | |
110 | // gemv kernels |
111 | gemv_fptr_t gemv_kernel[2] = {nullptr}; |
112 | gemv_s8s8s32_fptr_t gemv_s8s8s32_kernel = nullptr; |
113 | gemv_s8u8s32_fptr_t gemv_s8u8s32_kernel = nullptr; |
114 | gemv_u8s8s32_fptr_t gemv_u8s8s32_kernel = nullptr; |
115 | |
116 | // copyA[trans][sum] |
117 | static copy_a_fptr_t copy_a_kern[2][2]; |
118 | |
119 | // copyB[trans][sum] |
120 | static copy_b_fptr_t copy_b_kern[2][2]; |
121 | |
122 | // kern[beta0][alpha1][col_off][row_off] |
123 | static gemm_fptr_t kern[2][2][2][2]; |
124 | |
125 | // gemv_kern[trans] |
126 | static gemv_fptr_t gemv_kern[2]; |
127 | |
128 | static gemv_s8s8s32_fptr_t gemv_s8s8s32_kern; |
129 | static gemv_s8u8s32_fptr_t gemv_s8u8s32_kern; |
130 | static gemv_u8s8s32_fptr_t gemv_u8s8s32_kern; |
131 | |
132 | template <bool is_trans> |
133 | static void copy_a_sum_ref(const dim_t *p_k, const dim_t *p_m, |
134 | const a_t *src, const dim_t *p_ld, const float *p_alpha, a_t *dst, |
135 | const dim_t *dummy1, const dim_t *dummy2, c_t *a_row_sum) { |
136 | |
137 | copy_a_kern[is_trans][no_sum]( |
138 | p_k, p_m, src, p_ld, p_alpha, dst, dummy1, dummy2, a_row_sum); |
139 | |
140 | dim_t k = *p_k; |
141 | dim_t m = *p_m; |
142 | dim_t ld = *p_ld; |
143 | |
144 | // Calculate op(A) row sum. |
145 | if (!is_trans) { |
146 | PRAGMA_OMP_SIMD() |
147 | for (dim_t i = 0; i < m; i++) |
148 | a_row_sum[i] = 0; |
149 | |
150 | for (dim_t j = 0; j < k; j++) { |
151 | PRAGMA_OMP_SIMD() |
152 | for (dim_t i = 0; i < m; i++) { |
153 | a_row_sum[i] += src[i + j * ld]; |
154 | } |
155 | } |
156 | } else { |
157 | for (dim_t i = 0; i < m; i++) { |
158 | c_t acc = 0; |
159 | |
160 | PRAGMA_OMP_SIMD(reduction(+ : acc)) |
161 | for (dim_t j = 0; j < k; j++) { |
162 | acc += src[j + i * ld]; |
163 | } |
164 | |
165 | a_row_sum[i] = acc; |
166 | } |
167 | } |
168 | } |
169 | |
170 | template <bool is_trans> |
171 | static void copy_b_sum_ref(const dim_t *p_k, const dim_t *p_n, |
172 | const b_t *src, const dim_t *p_ld, const float *alpha, b_t *dst, |
173 | const dim_t *dummy1, const dim_t *dummy2, c_t *b_col_sum) { |
174 | |
175 | copy_b_kern[is_trans][no_sum]( |
176 | p_k, p_n, src, p_ld, alpha, dst, dummy1, dummy2, b_col_sum); |
177 | |
178 | dim_t k = *p_k; |
179 | dim_t n = *p_n; |
180 | dim_t ld = *p_ld; |
181 | |
182 | // Calculate op(B) column sum. |
183 | if (!is_trans) { |
184 | for (dim_t j = 0; j < n; j++) { |
185 | c_t acc = 0; |
186 | |
187 | PRAGMA_OMP_SIMD(reduction(+ : acc)) |
188 | for (dim_t i = 0; i < k; i++) |
189 | acc += src[i + j * ld]; |
190 | |
191 | b_col_sum[j] = acc; |
192 | } |
193 | } else { |
194 | PRAGMA_OMP_SIMD() |
195 | for (dim_t j = 0; j < n; j++) |
196 | b_col_sum[j] = 0; |
197 | |
198 | for (dim_t i = 0; i < k; i++) { |
199 | PRAGMA_OMP_SIMD() |
200 | for (dim_t j = 0; j < n; j++) |
201 | b_col_sum[j] += src[j + i * ld]; |
202 | } |
203 | } |
204 | } |
205 | |
206 | bool force_nocopy; |
207 | |
208 | gemm_info_t(const char *transA, const char *transB, const char *offsetC, |
209 | const dim_t *m, const dim_t *n, const dim_t *k, const float *alpha, |
210 | const a_t *a, const dim_t *lda, const a_t *oa, const b_t *b, |
211 | const dim_t *ldb, const b_t *ob, const float *beta, c_t *c, |
212 | const dim_t *ldc, const c_t *oc, bool force_nocopy, |
213 | pack_type packing, gemm_pack_storage_t *pack_dst, |
214 | bool measure_only); |
215 | |
216 | bool hasKernels(void); |
217 | |
218 | void update_blocking(const gemm_threading_t &thread_info); |
219 | |
220 | private: |
221 | void jit_init(void); |
222 | }; |
223 | |
224 | } // namespace x64 |
225 | } // namespace cpu |
226 | } // namespace impl |
227 | } // namespace dnnl |
228 | |
229 | #endif // CPU_X64_GEMM_GEMM_INFO_HPP |
230 | |