/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include <assert.h>
#include <float.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/platform.hpp"

#include "cpu/gemm/gemm.hpp"

#include "cpu/binary_injector_utils.hpp"
#include "cpu/matmul/gemm_bf16_matmul.hpp"
#include "cpu/matmul/matmul_utils.hpp"
#include "cpu/scale_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace matmul {

using namespace data_type;

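// gemm_bf16_matmul_t lowers a bf16 x bf16 matmul (f32 or bf16 destination)
// onto the gemm_bf16bf16f32() routine; bias, scales and post-ops that GEMM
// cannot apply directly are handled by a separate post-processing (pp) kernel.
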
template <impl::data_type_t dst_type>
status_t gemm_bf16_matmul_t<dst_type>::pd_t::init(engine_t *engine) {
    auto check_bias = [&]() -> bool {
        return !with_bias()
                || (utils::one_of(weights_md(1)->data_type, f32, bf16)
                        && is_bias_1xN());
    };

    bool ok = !has_zero_dim_memory() && src_md()->data_type == src_type
            && weights_md()->data_type == weights_type
            && desc()->accum_data_type == acc_type
            && dst_md()->data_type == dst_type
            && platform::has_data_type_support(data_type::bf16) && check_bias()
            && attr()->has_default_values(
                    primitive_attr_t::skip_mask_t::scales_runtime
                    | primitive_attr_t::skip_mask_t::post_ops)
            && set_default_formats()
            && attr_.set_default_formats(dst_md(0)) == status::success
            && gemm_based::check_gemm_compatible_formats(*this);
    if (!ok) return status::unimplemented;

    CHECK(check_and_configure_attributes());

    nthr_ = dnnl_get_max_threads();
    gemm_based::book_acc_scratchpad(*this, params_, sizeof(acc_data_t), nthr_);
    auto scratchpad = scratchpad_registry().registrar();
    book_precomputed_scales(scratchpad, attr()->scales_, N());

    return status::success;
}
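
// A sum post-op (at index 0) can be folded into the GEMM beta parameter only
// when the destination type is f32, GEMM applies the output scales itself,
// and the sum post-op carries no zero-point.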
static bool should_gemm_execute_sum_po(const gemm_based::params_t &params,
        impl::data_type_t dst_type) noexcept {
    const auto &po = params.pp_attr_.post_ops_;
    static constexpr int sum_idx = 0;
    return po.len() > 0 && po.contain(primitive_kind::sum, sum_idx)
            && dst_type == data_type::f32 && params.gemm_applies_output_scales_
            && po.entry_[sum_idx].sum.zero_point == 0;
}

template <impl::data_type_t dst_type>
status_t gemm_bf16_matmul_t<dst_type>::pd_t::check_and_configure_attributes() {
    auto check_attr_scales = [&]() -> bool {
        using namespace data_type;
        const std::vector<int> supported_args
                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
        bool ok = attr()->scales_.has_default_values(supported_args);
        for (int arg : supported_args) {
            const auto &mask = attr()->scales_.get(arg).mask_;
            if (arg == DNNL_ARG_WEIGHTS)
                ok = ok && (mask == 0 || mask == (1 << (dst_md()->ndims - 1)));
            else
                ok = ok && (mask == 0);
        }

        if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values()
                && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()
                && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) {
            // This case requires a scratchpad of unknown size
            if (N() == DNNL_RUNTIME_DIM_VAL) return false;
        }
        return ok;
    };
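
    // Post-ops are validated against the broadcast strategies supported by
    // the pp kernel / binary injector; binary post-ops with per_oc broadcast
    // additionally require GEMM-compatible destination formats.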
108
109 auto check_attr_post_ops = [&]() -> bool {
110 using namespace primitive_kind;
111 const auto &post_ops = attr()->post_ops_;
112 static const bcast_set_t enabled_bcast_strategy {
113 broadcasting_strategy_t::scalar,
114 broadcasting_strategy_t::per_oc,
115 broadcasting_strategy_t::per_oc_spatial,
116 broadcasting_strategy_t::per_mb_spatial,
117 broadcasting_strategy_t::per_mb_w,
118 broadcasting_strategy_t::per_w,
119 broadcasting_strategy_t::no_broadcast};
120 const bool is_binary_po_per_oc
121 = binary_injector_utils::bcast_strategy_present(
122 binary_injector_utils::extract_bcast_strategies(
123 post_ops.entry_, dst_md()),
124 broadcasting_strategy_t::per_oc);
125 return cpu::inner_product_utils::post_ops_ok(
126 post_ops, dst_md(), enabled_bcast_strategy)
127 && IMPLICATION(is_binary_po_per_oc,
128 gemm_based::check_gemm_binary_per_oc_compatible_formats(
129 *this));
130 };

    // check basic attributes
    if (!check_attr_scales()) return status::unimplemented;

    // set state
    CHECK(params_.pp_attr_.copy_from(*attr()));
    params_.gemm_applies_output_scales_
            = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 && !with_bias();

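    // If GEMM applies the scales itself, drop them from the post-processing
    // attributes so the pp kernel does not apply them a second time.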
    if (params_.gemm_applies_output_scales_) {
        params_.pp_attr_.scales_.reset(DNNL_ARG_SRC);
        params_.pp_attr_.scales_.reset(DNNL_ARG_WEIGHTS);
    }

    // check post-ops
    if (!check_attr_post_ops()) return status::unimplemented;
    const bool sum_po_via_gemm_beta
            = should_gemm_execute_sum_po(params_, dst_type);
    // set state
    params_.dst_is_acc_ = dst_type == data_type::f32
            && IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1,
                    sum_po_via_gemm_beta);

    if (sum_po_via_gemm_beta) {
        // set state
        const auto &po = params_.pp_attr_.post_ops_;
        static constexpr int sum_idx = 0;
        params_.gemm_beta_ = po.entry_[sum_idx].sum.scale;
    }

    // set state
    params_.has_pp_kernel_ = !params_.dst_is_acc_ || with_bias()
            || !params_.pp_attr_.has_default_values();

    return status::success;
}

template <impl::data_type_t dst_type>
bool gemm_bf16_matmul_t<dst_type>::should_skip_sum_po() const noexcept {
    return should_gemm_execute_sum_po(pd()->params(), dst_type);
}
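
// execute_ref() runs the bf16 GEMM either as many per-thread GEMM calls over
// (batch, M, N) chunks, or as a single GEMM with the batch collapsed into M,
// and then lets the pp kernel apply bias, scales and remaining post-ops.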
template <impl::data_type_t dst_type>
status_t gemm_bf16_matmul_t<dst_type>::execute_ref(
        const exec_ctx_t &ctx) const {
    using namespace binary_injector_utils;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const weights_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto &po = this->pd()->attr()->post_ops_;
    const auto post_ops_binary_rhs_arg_vec = prepare_binary_args(po, ctx);

    const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
    const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
    const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());

    const int ndims = pd()->ndims();

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const float *scales = precompute_scales(scratchpad, src_scales, wei_scales,
            dst_d.dims()[ndims - 1], pd()->attr());

    if (src_d.has_zero_dim() || weights_d.has_zero_dim()
            || dst_d.has_zero_dim())
        return status::success;

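    // matmul_helper_t exposes the GEMM view of the (possibly batched) tensors:
    // M, N, K, transpose flags, leading dimensions and the flattened batch.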
    matmul_helper_t helper(src_d, weights_d, dst_d);
    const int batch_ndims = ndims - 2;
    dim_t M = helper.M();
    const dim_t N = helper.N();
    const dim_t K = helper.K();
    const dim_t batch = helper.batch();
    const dim_t batch_without_dim0
            = helper.ndims() > 3 ? batch / dst_d.dims()[0] : 0;
    const dim_t batch_without_dim01
            = helper.ndims() > 4 ? batch_without_dim0 / dst_d.dims()[1] : 1;
    const char transA = helper.transA();
    const char transB = helper.transB();
    const dim_t lda = helper.lda();
    const dim_t ldb = helper.ldb();
    const dim_t ldc = helper.ldc();
    const int nthr = pd()->nthr_;

    const gemm_based::params_t &params = pd()->params();
    const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
            ? helper.can_fuse_src_batch_dims()
            : params.can_fuse_src_batch_dims_;
    const dim_t acc_stride = gemm_based::get_scratchpad_size(
            batch, M, N, can_fuse_src_batch_dims, nthr);
    bool dst_is_acc = params.dst_is_acc_;
    acc_data_t *acc = dst_is_acc
            ? (acc_data_t *)dst
            : ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    memory_tracking::names::key_matmul_dst_in_acc_dt);
    // case: dynamic sizes
    bool need_free_acc = false;
    if (acc == nullptr) {
        acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride
                        * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
                64);
        if (acc == nullptr) return status::out_of_memory;
        need_free_acc = true;
    }

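    // alpha carries the combined scale when GEMM applies the output scales
    // itself; beta is non-zero only when a sum post-op was folded into GEMM.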
    const float alpha = params.get_gemm_alpha(scales);
    const float beta = params.gemm_beta_;
    const dim_t acc_ldc = dst_is_acc ? ldc : N;
    const int scale_idx_mult
            = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_
            == (1 << (ndims - 1));

    std::atomic<status_t> st(status::success);
    // use parallelization over batch when a binary post-op uses channel
    // broadcast (except batch == 1)
    bool is_binary_po_per_oc;
    bool is_binary_po_per_oc_sp;
    bool is_binary_po_channel_bcast;
    std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
            is_binary_po_channel_bcast)
            = bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
                    broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial);
    // if batched, parallelize over batch for per_mb_sp and per_oc binary
    // post-op broadcast
    const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
            && IMPLICATION(
                    is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
    const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
    if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
        const int src_mask
                = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
        const int wei_mask
                = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
        const size_t bia_dt_size = !pd()->with_bias()
                ? 0
                : types::data_type_size(pd()->weights_md(1)->data_type);
        const size_t work_amount = (size_t)batch * M * N;
        const size_t work_per_batch = (size_t)M * N;

        // NOTE: inside the lambda, type-cast variables captured by reference
        // using either C-like "(type)var" or functional "type(var)" notation
        // in order to avoid a gcc bug with the c++14 standard. Otherwise,
        // capture by value.
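
        // Each thread walks a contiguous range of the flattened (batch, M, N)
        // work space and, depending on how much of its range remains, issues
        // a GEMM over a whole batch matrix, over a block of rows, or over
        // part of a single row.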
        parallel(nthr, [=, &st](int ithr, int nthr) {
            size_t t_work_start {0}, t_work_end {0};
            balance211(work_amount, nthr, ithr, t_work_start, t_work_end);

            dim_t cur_b {0}, cur_m {0}, cur_n {0};
            dims_t s_dims_idx, w_dims_idx, d_dims_idx;
            size_t i_work = t_work_start;
            const bool reuse_acc = acc != (acc_data_t *)dst;
            acc_data_t *curr_acc
                    = reuse_acc ? acc + ithr * acc_stride : nullptr;

            while (i_work < t_work_end) {
                utils::nd_iterator_init(
                        i_work, cur_b, batch, cur_m, M, cur_n, N);

                utils::l_dims_by_l_offset(
                        d_dims_idx, i_work, dst_d.dims(), ndims);
                utils::copy_dims_with_mask(
                        s_dims_idx, d_dims_idx, batch_ndims, src_mask);
                s_dims_idx[ndims - 2] = cur_m;
                s_dims_idx[ndims - 1] = 0; // k idx is always 0

                utils::copy_dims_with_mask(
                        w_dims_idx, d_dims_idx, batch_ndims, wei_mask);
                w_dims_idx[ndims - 2] = 0; // k idx is always 0
                w_dims_idx[ndims - 1] = cur_n;
                const src_data_t *curr_src = src + src_d.off_v(s_dims_idx);
                const weights_data_t *curr_weights
                        = weights + weights_d.off_v(w_dims_idx);
                const dim_t dst_off = dst_d.off_v(d_dims_idx);
                dst_data_t *curr_dst = dst + dst_off;
                if (!reuse_acc) curr_acc = acc + dst_off;
                dim_t gemm_M {0}, gemm_N {0};

                size_t matrix_offset;
                const size_t rem_work = t_work_end - i_work;
                if (rem_work >= work_per_batch && cur_m == 0 && cur_n == 0) {
                    // parallel over batch
                    gemm_M = M;
                    gemm_N = N;
                    matrix_offset = 0;
                } else if (rem_work >= (size_t)N && cur_n == 0) {
                    // parallel over M
                    gemm_M = nstl::min(
                            (size_t)(M - cur_m), (size_t)(rem_work / N));
                    gemm_N = N;
                    matrix_offset = cur_n + cur_m * N;
                } else {
                    // parallel over N
                    gemm_M = 1;
                    gemm_N = nstl::min((size_t)(N - cur_n), rem_work);
                    matrix_offset = cur_n + cur_m * N;
                }
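
                // gemm_bf16bf16f32() follows the column-major BLAS convention,
                // so the arguments are swapped (B, A, N, M) to compute the
                // row-major product acc = alpha * src * weights + beta * acc.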
                status_t st_thr = gemm_bf16bf16f32(&transB, &transA, &gemm_N,
                        &gemm_M, &K, &alpha, curr_weights, &ldb, curr_src,
                        &lda, &beta, curr_acc, &acc_ldc);
                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (params.has_pp_kernel_) {
                    const float *pp_scales
                            = params.get_post_processing_scales(scales);
                    const size_t dst_logical_off = i_work;
                    const size_t dim1_off = helper.ndims() > 3
                            ? ((cur_b % batch_without_dim0)
                                    / batch_without_dim01)
                            : cur_m;
                    // offset for the case with post-op channel broadcast
                    const size_t matrix_per_first_batch_off = helper.ndims() > 3
                            ? M * N * (cur_b / batch_without_dim0)
                                    + matrix_offset
                            : 0;
                    const ptrdiff_t oc_off = i_work % N;
                    (*pp_kernel_)(curr_dst, curr_acc,
                            bias + oc_off * bia_dt_size,
                            pp_scales + oc_off * scale_idx_mult, dst_scales[0],
                            0, dst_logical_off, dim1_off, gemm_M * gemm_N,
                            static_cast<size_t>(N), ldc, nullptr,
                            post_ops_binary_rhs_arg_vec.data(), dst,
                            matrix_per_first_batch_off, ctx, *pd()->dst_md());
                }
                i_work += gemm_M * gemm_N;
            }
        });
    } else {
        // collapse batch into M, if weights batch dimensions are broadcasted.
        M = M * batch;

        st = gemm_bf16bf16f32(&transB, &transA, &N, &M, &K, &alpha, weights,
                &ldb, src, &lda, &beta, acc, &acc_ldc);

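        // Post-processing applies bias, scales and post-ops on top of the
        // accumulator; some pp kernels must run sequentially.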
        if (st == status::success && params.has_pp_kernel_) {
            const bool force_sequential = pp_kernel_->sequential_kernel();
            const float *pp_scales = params.get_post_processing_scales(scales);

            parallel(force_sequential ? 1 : nthr, [&](int ithr, int nthr) {
                size_t start {}, end {};
                balance211((size_t)(M * N), nthr, ithr, start, end);
                const size_t dst_logical_off = start;
                const size_t dim1_off = start % N;
                (*pp_kernel_)(dst, acc, bias, pp_scales, dst_scales[0], start,
                        dst_logical_off, dim1_off, end, (size_t)N, ldc,
                        nullptr, post_ops_binary_rhs_arg_vec.data(), dst, 0,
                        ctx, *pd()->dst_md());
            });
        }
    }

    if (need_free_acc) free(acc);

    return st;
}

using namespace data_type;
template struct gemm_bf16_matmul_t<data_type::f32>;
template struct gemm_bf16_matmul_t<data_type::bf16>;

} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace dnnl