1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <algorithm> |
18 | |
19 | #include "common/bfloat16.hpp" |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/x64/gemm_bf16_inner_product.hpp" |
25 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
26 | |
27 | #include "cpu/binary_injector_utils.hpp" |
28 | #include "cpu/cpu_primitive.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
35 | using namespace dnnl::impl::status; |
36 | using namespace dnnl::impl::prop_kind; |
37 | using namespace dnnl::impl::data_type; |
38 | using namespace dnnl::impl::format_tag; |
39 | using namespace dnnl::impl::primitive_kind; |
40 | using namespace memory_tracking::names; |
41 | using namespace dnnl::impl::cpu::x64::bf16_support; |
42 | |
43 | template <data_type_t dst_data_type> |
44 | status_t gemm_bf16_inner_product_fwd_t<dst_data_type>::execute_forward( |
45 | const exec_ctx_t &ctx) const { |
46 | auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
47 | auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
48 | auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
49 | auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); |
50 | const auto post_ops_binary_rhs_arg_vec |
51 | = binary_injector_utils::prepare_binary_args( |
52 | this->pd()->attr()->post_ops_, ctx); |
53 | |
54 | const dim_t M = pd()->OC(); |
55 | const dim_t N = pd()->MB(); |
56 | const dim_t K = pd()->IC_total_padded(); |
57 | |
58 | const auto &wmd = *pd()->weights_md(); |
59 | const auto &smd = *pd()->src_md(); |
60 | bool wei_tr = wmd.format_desc.blocking.strides[0] != 1; |
61 | // check if MB is the leading dimension |
62 | bool src_tr = smd.format_desc.blocking.strides[0] == 1 && K > 1; |
63 | |
64 | acc_data_t *acc = pd()->dst_is_acc_ |
65 | ? (acc_data_t *)dst |
66 | : ctx.get_scratchpad_grantor().template get<acc_data_t>( |
67 | key_iprod_int_dat_in_acc_dt); |
68 | |
69 | float alpha = 1.0; |
70 | status_t st = gemm_bf16bf16f32(wei_tr ? "T" : "N" , src_tr ? "T" : "N" , &M, |
71 | &N, &K, &alpha, weights, wei_tr ? &K : &M, src, src_tr ? &N : &K, |
72 | &beta_, acc, &M); |
73 | if (st != status::success) return st; |
74 | |
75 | if (postops_in_ip_) { |
76 | const bool force_sequential = pp_kernel_->sequential_kernel(); |
77 | parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { |
78 | size_t start = 0, end = 0; |
79 | size_t work_size = M * N; |
80 | balance211(work_size, nthr, ithr, start, end); |
81 | const size_t dst_logical_off = start; |
82 | const size_t dim1_off = start % M; |
83 | (*pp_kernel_)(dst, acc, bias, nullptr, 1.0f, start, dst_logical_off, |
84 | dim1_off, end, 0, 0, nullptr, |
85 | post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx, |
86 | *pd()->dst_md()); |
87 | }); |
88 | } |
89 | |
90 | return st; |
91 | } |
92 | |
93 | template <data_type_t diff_src_data_type> |
94 | status_t |
95 | gemm_bf16_inner_product_bwd_data_t<diff_src_data_type>::execute_backward_data( |
96 | const exec_ctx_t &ctx) const { |
97 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
98 | auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
99 | auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); |
100 | |
101 | const dim_t M = pd()->IC_total_padded(); |
102 | const dim_t N = pd()->MB(); |
103 | const dim_t K = pd()->OC(); |
104 | |
105 | const auto &wmd = *pd()->weights_md(); |
106 | const auto &smd = *pd()->diff_src_md(); |
107 | bool wei_tr = wmd.format_desc.blocking.strides[0] == 1; |
108 | // check if MB is the leading dimension |
109 | bool dsrc_tr = smd.format_desc.blocking.strides[0] == 1 && M > 1; |
110 | |
111 | acc_data_t *acc = pd()->diff_src_is_acc_ |
112 | ? (acc_data_t *)diff_src |
113 | : ctx.get_scratchpad_grantor().template get<acc_data_t>( |
114 | key_iprod_int_dat_in_acc_dt); |
115 | |
116 | float alpha = 1.0, beta = 0.0; |
117 | status_t st = status::success; |
118 | if (dsrc_tr) |
119 | st = gemm_bf16bf16f32(wei_tr ? "T" : "N" , "N" , &K, &M, &N, &alpha, |
120 | diff_dst, &K, weights, wei_tr ? &K : &M, &beta, acc, &N); |
121 | else |
122 | st = gemm_bf16bf16f32(wei_tr ? "T" : "N" , "N" , &M, &N, &K, &alpha, |
123 | weights, wei_tr ? &K : &M, diff_dst, &K, &beta, acc, &M); |
124 | if (st != status::success) return st; |
125 | |
126 | if (!pd()->diff_src_is_acc_) { |
127 | parallel(0, [&](int ithr, int nthr) { |
128 | size_t start = 0, end = 0; |
129 | size_t work_size = M * N; |
130 | balance211(work_size, nthr, ithr, start, end); |
131 | if (end > start) |
132 | cvt_float_to_bfloat16((bfloat16_t *)&diff_src[start], |
133 | (const float *)&acc[start], end - start); |
134 | }); |
135 | } |
136 | |
137 | return status::success; |
138 | } |
139 | |
140 | template <data_type_t diff_wei_data_type> |
141 | status_t gemm_bf16_inner_product_bwd_weights_t<diff_wei_data_type>:: |
142 | execute_backward_weights(const exec_ctx_t &ctx) const { |
143 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
144 | auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
145 | auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
146 | |
147 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
148 | diff_dst += diff_dst_d.offset0(); |
149 | |
150 | const dim_t MB = pd()->MB(); |
151 | const dim_t OC = pd()->OC(); |
152 | const dim_t IC = pd()->IC_total_padded(); |
153 | |
154 | const auto &wmd = *pd()->diff_weights_md(); |
155 | const auto &smd = *pd()->src_md(); |
156 | bool wei_tr = wmd.format_desc.blocking.strides[0] == 1; |
157 | // check if MB is the leading dimension |
158 | bool src_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1; |
159 | |
160 | acc_data_t *acc = pd()->diff_wei_is_acc_ |
161 | ? (acc_data_t *)diff_weights |
162 | : ctx.get_scratchpad_grantor().template get<acc_data_t>( |
163 | key_iprod_int_dat_in_acc_dt); |
164 | |
165 | float alpha = 1.0, beta = 0.0; |
166 | status_t st = status::success; |
167 | if (wei_tr) |
168 | st = gemm_bf16bf16f32("N" , src_tr ? "N" : "T" , &OC, &IC, &MB, &alpha, |
169 | diff_dst, &OC, src, src_tr ? &MB : &IC, &beta, acc, &OC); |
170 | else |
171 | st = gemm_bf16bf16f32("N" , src_tr ? "N" : "T" , &IC, &OC, &MB, &alpha, |
172 | src, src_tr ? &MB : &IC, diff_dst, &OC, &beta, acc, &IC); |
173 | |
174 | if (st != status::success) return st; |
175 | |
176 | if (!pd()->diff_wei_is_acc_) { |
177 | parallel(0, [&](int ithr, int nthr) { |
178 | constexpr size_t blksize = 64; |
179 | size_t start = 0, end = 0; |
180 | size_t work_size = OC * IC; |
181 | balance211( |
182 | utils::div_up(work_size, blksize), nthr, ithr, start, end); |
183 | start = std::min(work_size, start * blksize); |
184 | end = std::min(work_size, end * blksize); |
185 | if (end > start) { |
186 | cvt_float_to_bfloat16((bfloat16_t *)&diff_weights[start], |
187 | (const float *)&acc[start], end - start); |
188 | } |
189 | }); |
190 | } |
191 | |
192 | execute_backward_bias(ctx); |
193 | |
194 | return status::success; |
195 | } |
196 | |
template <data_type_t diff_wei_data_type>
void gemm_bf16_inner_product_bwd_weights_t<diff_wei_data_type>::
        execute_backward_bias(const exec_ctx_t &ctx) const {
    // Computes diff_bias(OC) by reducing diff_dst(MB x OC) over MB.
    // Work is split over a nthr_OCB x nthr_MB thread grid chosen by the
    // primitive descriptor; when nthr_MB > 1 each thread produces a partial
    // sum and a second parallel pass combines them.
    if (!pd()->with_bias()) return;

    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto diff_bias = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_BIAS);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));

    // diff_bias is typed as char*, so its offset must be scaled by the
    // element size of its data type (f32 or bf16).
    diff_dst += diff_dst_d.offset0();
    diff_bias += diff_bias_d.data_type_size() * diff_bias_d.offset0();

    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();

    // OC is partitioned in blocks of bias_blksize; OCB is the block count.
    constexpr dim_t blksize = pd_t::bias_blksize;
    const dim_t OCB = utils::div_up(OC, blksize);

    // Thread-grid shape and per-thread OC chunk size from the pd.
    dim_t OC_per_thread {0};
    int nthr_OCB {0}, nthr_MB {0};
    pd()->get_bias_partitioning(OC_per_thread, nthr_OCB, nthr_MB);

    // The final result can be accumulated in place only when no cross-MB
    // reduction is needed and diff_bias itself is f32.
    const bool diff_bias_is_acc
            = nthr_MB == 1 && diff_bias_d.data_type() == data_type::f32;
    float *diff_bias_acc = diff_bias_is_acc
            ? (float *)diff_bias
            : (float *)ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_iprod_bias_bf16_convert_wsp);

    // Pass 1: each active thread reduces its (oc block range, MB range) tile
    // of diff_dst into its own f32 accumulator slice.
    parallel(pd()->bias_reduction_nthr_, [&](int ithr, int nthr) {
        if (ithr < nthr_OCB * nthr_MB) {
            // Map the flat thread id onto the (OCB, MB) grid.
            const int ithr_MB = ithr / nthr_OCB;
            const int ithr_OCB = ithr % nthr_OCB;

            dim_t ocb_s {0}, ocb_e {0};
            balance211(OCB, nthr_OCB, ithr_OCB, ocb_s, ocb_e);
            // Clamp the last block to OC (OC need not be a blksize multiple).
            const dim_t oc_s = std::min(ocb_s * blksize, OC);
            const dim_t oc_e = std::min(ocb_e * blksize, OC);
            const dim_t oc_len = oc_e - oc_s;

            dim_t mb_s {0}, mb_e {0};
            balance211(MB, nthr_MB, ithr_MB, mb_s, mb_e);
            const dim_t mb_len = mb_e - mb_s;

            // In-place: write at the real oc offset. Otherwise each thread
            // owns a private OC_per_thread-sized slice of the workspace.
            const dim_t db_offset = diff_bias_is_acc
                    ? oc_s
                    : (ithr_OCB * nthr_MB + ithr_MB) * OC_per_thread;
            float *db = diff_bias_acc + db_offset;

            // Zero the slice; the reduction kernel accumulates into it.
            PRAGMA_OMP_SIMD()
            for (dim_t oc = 0; oc < oc_len; ++oc)
                db[oc] = 0;

            // Reduce this thread's diff_dst tile (mb_len rows of oc_len).
            (*bias_reduction_)(db, &((bfloat16_t *)diff_dst)[mb_s * OC + oc_s],
                    (size_t)oc_len, (size_t)mb_len);

            // No cross-MB reduction pending and bf16 destination: convert
            // the finished slice to bf16 right away.
            if (!diff_bias_is_acc && nthr_MB == 1)
                cvt_float_to_bfloat16(
                        &((bfloat16_t *)diff_bias)[oc_s], db, oc_len);
        }
    });

    if (nthr_MB == 1) return; // no reduction required

    // Pass 2: one thread per OCB chunk sums the nthr_MB partial results and
    // writes the final diff_bias (f32 copy or bf16 conversion).
    parallel(pd()->bias_reduction_nthr_, [&](int ithr, int nthr) {
        if (ithr < nthr_OCB) {
            const int ithr_OCB = ithr;

            dim_t ocb_s {0}, ocb_e {0};
            balance211(OCB, nthr_OCB, ithr_OCB, ocb_s, ocb_e);
            const dim_t oc_s = std::min(ocb_s * blksize, OC);
            const dim_t oc_e = std::min(ocb_e * blksize, OC);
            const dim_t oc_len = oc_e - oc_s;

            // First partial slice for this OCB chunk; accumulate the other
            // nthr_MB - 1 slices into it.
            float *db = diff_bias_acc + ithr_OCB * nthr_MB * OC_per_thread;

            for (dim_t thr_MB = 1; thr_MB < nthr_MB; ++thr_MB) {
                const float *thr_db = db + thr_MB * OC_per_thread;

                PRAGMA_OMP_SIMD()
                for (dim_t oc = 0; oc < oc_len; ++oc)
                    db[oc] += thr_db[oc];
            }

            if (diff_bias_d.data_type() == data_type::f32) {
                float *res = &((float *)diff_bias)[oc_s];

                PRAGMA_OMP_SIMD()
                for (dim_t oc = 0; oc < oc_len; ++oc)
                    res[oc] = db[oc];
            } else {
                cvt_float_to_bfloat16(
                        &((bfloat16_t *)diff_bias)[oc_s], db, oc_len);
            }
        }
    });
}
296 | |
// Explicit instantiations: bf16 inner product supports f32 or bf16 for the
// dst (fwd), diff_src (bwd data) and diff_weights (bwd weights) tensors.
template struct gemm_bf16_inner_product_fwd_t<data_type::f32>;
template struct gemm_bf16_inner_product_fwd_t<data_type::bf16>;
template struct gemm_bf16_inner_product_bwd_data_t<data_type::f32>;
template struct gemm_bf16_inner_product_bwd_data_t<data_type::bf16>;
template struct gemm_bf16_inner_product_bwd_weights_t<data_type::f32>;
template struct gemm_bf16_inner_product_bwd_weights_t<data_type::bf16>;
303 | |
304 | } // namespace x64 |
305 | } // namespace cpu |
306 | } // namespace impl |
307 | } // namespace dnnl |
308 | |
309 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
310 | |