1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <algorithm>
18
19#include "common/bfloat16.hpp"
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/type_helpers.hpp"
23
24#include "cpu/x64/gemm_bf16_inner_product.hpp"
25#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
26
27#include "cpu/binary_injector_utils.hpp"
28#include "cpu/cpu_primitive.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33namespace x64 {
34
35using namespace dnnl::impl::status;
36using namespace dnnl::impl::prop_kind;
37using namespace dnnl::impl::data_type;
38using namespace dnnl::impl::format_tag;
39using namespace dnnl::impl::primitive_kind;
40using namespace memory_tracking::names;
41using namespace dnnl::impl::cpu::x64::bf16_support;
42
43template <data_type_t dst_data_type>
44status_t gemm_bf16_inner_product_fwd_t<dst_data_type>::execute_forward(
45 const exec_ctx_t &ctx) const {
46 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
47 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
48 auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
49 auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
50 const auto post_ops_binary_rhs_arg_vec
51 = binary_injector_utils::prepare_binary_args(
52 this->pd()->attr()->post_ops_, ctx);
53
54 const dim_t M = pd()->OC();
55 const dim_t N = pd()->MB();
56 const dim_t K = pd()->IC_total_padded();
57
58 const auto &wmd = *pd()->weights_md();
59 const auto &smd = *pd()->src_md();
60 bool wei_tr = wmd.format_desc.blocking.strides[0] != 1;
61 // check if MB is the leading dimension
62 bool src_tr = smd.format_desc.blocking.strides[0] == 1 && K > 1;
63
64 acc_data_t *acc = pd()->dst_is_acc_
65 ? (acc_data_t *)dst
66 : ctx.get_scratchpad_grantor().template get<acc_data_t>(
67 key_iprod_int_dat_in_acc_dt);
68
69 float alpha = 1.0;
70 status_t st = gemm_bf16bf16f32(wei_tr ? "T" : "N", src_tr ? "T" : "N", &M,
71 &N, &K, &alpha, weights, wei_tr ? &K : &M, src, src_tr ? &N : &K,
72 &beta_, acc, &M);
73 if (st != status::success) return st;
74
75 if (postops_in_ip_) {
76 const bool force_sequential = pp_kernel_->sequential_kernel();
77 parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
78 size_t start = 0, end = 0;
79 size_t work_size = M * N;
80 balance211(work_size, nthr, ithr, start, end);
81 const size_t dst_logical_off = start;
82 const size_t dim1_off = start % M;
83 (*pp_kernel_)(dst, acc, bias, nullptr, 1.0f, start, dst_logical_off,
84 dim1_off, end, 0, 0, nullptr,
85 post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx,
86 *pd()->dst_md());
87 });
88 }
89
90 return st;
91}
92
93template <data_type_t diff_src_data_type>
94status_t
95gemm_bf16_inner_product_bwd_data_t<diff_src_data_type>::execute_backward_data(
96 const exec_ctx_t &ctx) const {
97 auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
98 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
99 auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
100
101 const dim_t M = pd()->IC_total_padded();
102 const dim_t N = pd()->MB();
103 const dim_t K = pd()->OC();
104
105 const auto &wmd = *pd()->weights_md();
106 const auto &smd = *pd()->diff_src_md();
107 bool wei_tr = wmd.format_desc.blocking.strides[0] == 1;
108 // check if MB is the leading dimension
109 bool dsrc_tr = smd.format_desc.blocking.strides[0] == 1 && M > 1;
110
111 acc_data_t *acc = pd()->diff_src_is_acc_
112 ? (acc_data_t *)diff_src
113 : ctx.get_scratchpad_grantor().template get<acc_data_t>(
114 key_iprod_int_dat_in_acc_dt);
115
116 float alpha = 1.0, beta = 0.0;
117 status_t st = status::success;
118 if (dsrc_tr)
119 st = gemm_bf16bf16f32(wei_tr ? "T" : "N", "N", &K, &M, &N, &alpha,
120 diff_dst, &K, weights, wei_tr ? &K : &M, &beta, acc, &N);
121 else
122 st = gemm_bf16bf16f32(wei_tr ? "T" : "N", "N", &M, &N, &K, &alpha,
123 weights, wei_tr ? &K : &M, diff_dst, &K, &beta, acc, &M);
124 if (st != status::success) return st;
125
126 if (!pd()->diff_src_is_acc_) {
127 parallel(0, [&](int ithr, int nthr) {
128 size_t start = 0, end = 0;
129 size_t work_size = M * N;
130 balance211(work_size, nthr, ithr, start, end);
131 if (end > start)
132 cvt_float_to_bfloat16((bfloat16_t *)&diff_src[start],
133 (const float *)&acc[start], end - start);
134 });
135 }
136
137 return status::success;
138}
139
140template <data_type_t diff_wei_data_type>
141status_t gemm_bf16_inner_product_bwd_weights_t<diff_wei_data_type>::
142 execute_backward_weights(const exec_ctx_t &ctx) const {
143 auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
144 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
145 auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);
146
147 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
148 diff_dst += diff_dst_d.offset0();
149
150 const dim_t MB = pd()->MB();
151 const dim_t OC = pd()->OC();
152 const dim_t IC = pd()->IC_total_padded();
153
154 const auto &wmd = *pd()->diff_weights_md();
155 const auto &smd = *pd()->src_md();
156 bool wei_tr = wmd.format_desc.blocking.strides[0] == 1;
157 // check if MB is the leading dimension
158 bool src_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1;
159
160 acc_data_t *acc = pd()->diff_wei_is_acc_
161 ? (acc_data_t *)diff_weights
162 : ctx.get_scratchpad_grantor().template get<acc_data_t>(
163 key_iprod_int_dat_in_acc_dt);
164
165 float alpha = 1.0, beta = 0.0;
166 status_t st = status::success;
167 if (wei_tr)
168 st = gemm_bf16bf16f32("N", src_tr ? "N" : "T", &OC, &IC, &MB, &alpha,
169 diff_dst, &OC, src, src_tr ? &MB : &IC, &beta, acc, &OC);
170 else
171 st = gemm_bf16bf16f32("N", src_tr ? "N" : "T", &IC, &OC, &MB, &alpha,
172 src, src_tr ? &MB : &IC, diff_dst, &OC, &beta, acc, &IC);
173
174 if (st != status::success) return st;
175
176 if (!pd()->diff_wei_is_acc_) {
177 parallel(0, [&](int ithr, int nthr) {
178 constexpr size_t blksize = 64;
179 size_t start = 0, end = 0;
180 size_t work_size = OC * IC;
181 balance211(
182 utils::div_up(work_size, blksize), nthr, ithr, start, end);
183 start = std::min(work_size, start * blksize);
184 end = std::min(work_size, end * blksize);
185 if (end > start) {
186 cvt_float_to_bfloat16((bfloat16_t *)&diff_weights[start],
187 (const float *)&acc[start], end - start);
188 }
189 });
190 }
191
192 execute_backward_bias(ctx);
193
194 return status::success;
195}
196
template <data_type_t diff_wei_data_type>
void gemm_bf16_inner_product_bwd_weights_t<diff_wei_data_type>::
        execute_backward_bias(const exec_ctx_t &ctx) const {
    // Computes diff_bias by reducing diff_dst over the minibatch dimension,
    // using a 2D (OC-blocks x MB-chunks) thread partitioning followed by a
    // cross-MB reduction pass when more than one MB chunk is used.
    if (!pd()->with_bias()) return;

    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto diff_bias = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_BIAS);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));

    diff_dst += diff_dst_d.offset0();
    // diff_bias is a char* because its data type (f32 or bf16) is only known
    // at run time, hence the byte-sized offset arithmetic.
    diff_bias += diff_bias_d.data_type_size() * diff_bias_d.offset0();

    const dim_t MB = pd()->MB();
    const dim_t OC = pd()->OC();

    // Output channels are partitioned among threads in blocks of `blksize`.
    constexpr dim_t blksize = pd_t::bias_blksize;
    const dim_t OCB = utils::div_up(OC, blksize);

    // Thread grid (nthr_OCB x nthr_MB) and per-thread scratch stride,
    // precomputed at primitive-descriptor creation time.
    dim_t OC_per_thread {0};
    int nthr_OCB {0}, nthr_MB {0};
    pd()->get_bias_partitioning(OC_per_thread, nthr_OCB, nthr_MB);

    // Accumulate directly into diff_bias when no cross-MB reduction is
    // needed and the output is already f32; otherwise accumulate into the
    // scratchpad, one private OC_per_thread-sized slot per thread.
    const bool diff_bias_is_acc
            = nthr_MB == 1 && diff_bias_d.data_type() == data_type::f32;
    float *diff_bias_acc = diff_bias_is_acc
            ? (float *)diff_bias
            : (float *)ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_iprod_bias_bf16_convert_wsp);

    // Pass 1: each active thread reduces its (OC block range, MB chunk) tile
    // of diff_dst into an f32 accumulator.
    parallel(pd()->bias_reduction_nthr_, [&](int ithr, int nthr) {
        if (ithr < nthr_OCB * nthr_MB) {
            const int ithr_MB = ithr / nthr_OCB;
            const int ithr_OCB = ithr % nthr_OCB;

            // OC range handled by this thread (clamped to OC for the tail
            // block).
            dim_t ocb_s {0}, ocb_e {0};
            balance211(OCB, nthr_OCB, ithr_OCB, ocb_s, ocb_e);
            const dim_t oc_s = std::min(ocb_s * blksize, OC);
            const dim_t oc_e = std::min(ocb_e * blksize, OC);
            const dim_t oc_len = oc_e - oc_s;

            // MB range handled by this thread.
            dim_t mb_s {0}, mb_e {0};
            balance211(MB, nthr_MB, ithr_MB, mb_s, mb_e);
            const dim_t mb_len = mb_e - mb_s;

            // Accumulator location: the final buffer itself (offset by oc_s)
            // or this thread's private scratchpad slot.
            const dim_t db_offset = diff_bias_is_acc
                    ? oc_s
                    : (ithr_OCB * nthr_MB + ithr_MB) * OC_per_thread;
            float *db = diff_bias_acc + db_offset;

            // Zero-initialize before accumulation.
            PRAGMA_OMP_SIMD()
            for (dim_t oc = 0; oc < oc_len; ++oc)
                db[oc] = 0;

            // JIT kernel: db[oc] += sum over the mb chunk of
            // diff_dst[mb, oc_s + oc].
            (*bias_reduction_)(db, &((bfloat16_t *)diff_dst)[mb_s * OC + oc_s],
                    (size_t)oc_len, (size_t)mb_len);

            // Single MB chunk but bf16 output: convert this tile right away.
            if (!diff_bias_is_acc && nthr_MB == 1)
                cvt_float_to_bfloat16(
                        &((bfloat16_t *)diff_bias)[oc_s], db, oc_len);
        }
    });

    if (nthr_MB == 1) return; // no reduction required

    // Pass 2: sum the nthr_MB partial results per OC block and store the
    // final diff_bias in its destination data type.
    parallel(pd()->bias_reduction_nthr_, [&](int ithr, int nthr) {
        if (ithr < nthr_OCB) {
            const int ithr_OCB = ithr;

            // Same OC range this thread column owned in pass 1.
            dim_t ocb_s {0}, ocb_e {0};
            balance211(OCB, nthr_OCB, ithr_OCB, ocb_s, ocb_e);
            const dim_t oc_s = std::min(ocb_s * blksize, OC);
            const dim_t oc_e = std::min(ocb_e * blksize, OC);
            const dim_t oc_len = oc_e - oc_s;

            // Reduce partial sums for MB chunks 1..nthr_MB-1 into chunk 0's
            // slot.
            float *db = diff_bias_acc + ithr_OCB * nthr_MB * OC_per_thread;

            for (dim_t thr_MB = 1; thr_MB < nthr_MB; ++thr_MB) {
                const float *thr_db = db + thr_MB * OC_per_thread;

                PRAGMA_OMP_SIMD()
                for (dim_t oc = 0; oc < oc_len; ++oc)
                    db[oc] += thr_db[oc];
            }

            // Store: plain copy for f32 output, down-convert for bf16.
            if (diff_bias_d.data_type() == data_type::f32) {
                float *res = &((float *)diff_bias)[oc_s];

                PRAGMA_OMP_SIMD()
                for (dim_t oc = 0; oc < oc_len; ++oc)
                    res[oc] = db[oc];
            } else {
                cvt_float_to_bfloat16(
                        &((bfloat16_t *)diff_bias)[oc_s], db, oc_len);
            }
        }
    });
}
296
// Explicit instantiations for the supported output data types: the compute
// path is bf16 in all cases; the template parameter selects the f32 or bf16
// storage type of dst / diff_src / diff_weights.
template struct gemm_bf16_inner_product_fwd_t<data_type::f32>;
template struct gemm_bf16_inner_product_fwd_t<data_type::bf16>;
template struct gemm_bf16_inner_product_bwd_data_t<data_type::f32>;
template struct gemm_bf16_inner_product_bwd_data_t<data_type::bf16>;
template struct gemm_bf16_inner_product_bwd_weights_t<data_type::f32>;
template struct gemm_bf16_inner_product_bwd_weights_t<data_type::bf16>;
303
304} // namespace x64
305} // namespace cpu
306} // namespace impl
307} // namespace dnnl
308
309// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
310