1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <atomic> |
18 | |
19 | #include <assert.h> |
20 | #include <float.h> |
21 | #include <math.h> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/dnnl_thread.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/cpu_primitive.hpp" |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/gemm/gemm.hpp" |
32 | |
33 | #include "cpu/binary_injector_utils.hpp" |
34 | #include "cpu/matmul/gemm_bf16_matmul.hpp" |
35 | #include "cpu/matmul/matmul_utils.hpp" |
36 | #include "cpu/scale_utils.hpp" |
37 | |
38 | namespace dnnl { |
39 | namespace impl { |
40 | namespace cpu { |
41 | namespace matmul { |
42 | |
43 | using namespace data_type; |
44 | |
45 | template <impl::data_type_t dst_type> |
46 | status_t gemm_bf16_matmul_t<dst_type>::pd_t::init(engine_t *engine) { |
47 | auto check_bias = [&]() -> bool { |
48 | return !with_bias() |
49 | || (utils::one_of(weights_md(1)->data_type, f32, bf16) |
50 | && is_bias_1xN()); |
51 | }; |
52 | |
53 | bool ok = !has_zero_dim_memory() && src_md()->data_type == src_type |
54 | && weights_md()->data_type == weights_type |
55 | && desc()->accum_data_type == acc_type |
56 | && dst_md()->data_type == dst_type |
57 | && platform::has_data_type_support(data_type::bf16) && check_bias() |
58 | && attr()->has_default_values( |
59 | primitive_attr_t::skip_mask_t::scales_runtime |
60 | | primitive_attr_t::skip_mask_t::post_ops) |
61 | && set_default_formats() |
62 | && attr_.set_default_formats(dst_md(0)) == status::success |
63 | && gemm_based::check_gemm_compatible_formats(*this); |
64 | if (!ok) return status::unimplemented; |
65 | |
66 | CHECK(check_and_configure_attributes()); |
67 | |
68 | nthr_ = dnnl_get_max_threads(); |
69 | gemm_based::book_acc_scratchpad(*this, params_, sizeof(acc_data_t), nthr_); |
70 | auto scratchpad = scratchpad_registry().registrar(); |
71 | book_precomputed_scales(scratchpad, attr()->scales_, N()); |
72 | |
73 | return status::success; |
74 | } |
75 | |
76 | static bool should_gemm_execute_sum_po(const gemm_based::params_t ¶ms, |
77 | impl::data_type_t dst_type) noexcept { |
78 | const auto &po = params.pp_attr_.post_ops_; |
79 | static constexpr int sum_idx = 0; |
80 | return po.len() > 0 && po.contain(primitive_kind::sum, sum_idx) |
81 | && dst_type == data_type::f32 && params.gemm_applies_output_scales_ |
82 | && po.entry_[sum_idx].sum.zero_point == 0; |
83 | } |
84 | |
template <impl::data_type_t dst_type>
status_t gemm_bf16_matmul_t<dst_type>::pd_t::check_and_configure_attributes() {
    // Accept only per-tensor scales for src/dst; weights scales may be
    // either per-tensor (mask 0) or per-N, i.e. per last dst dimension.
    auto check_attr_scales = [&]() -> bool {
        using namespace data_type;
        const std::vector<int> supported_args
                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
        bool ok = attr()->scales_.has_default_values(supported_args);
        for (int arg : supported_args) {
            const auto &mask = attr()->scales_.get(arg).mask_;
            if (arg == DNNL_ARG_WEIGHTS)
                ok = ok && (mask == 0 || mask == (1 << (dst_md()->ndims - 1)));
            else
                ok = ok && (mask == 0);
        }

        // Combining src scales with per-N weights scales requires an
        // N-sized scratchpad for the pre-computed products, which cannot
        // be booked when N is a runtime dimension.
        if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values()
                && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()
                && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) {
            // This case requires scratchpad with unknown size
            if (N() == DNNL_RUNTIME_DIM_VAL) return false;
        }
        return ok;
    };

    // Post-ops are accepted when the inner-product post-processing kernel
    // supports them with one of the enabled broadcast strategies; per_oc
    // binary broadcast additionally requires GEMM-compatible formats.
    auto check_attr_post_ops = [&]() -> bool {
        using namespace primitive_kind;
        const auto &post_ops = attr()->post_ops_;
        static const bcast_set_t enabled_bcast_strategy {
                broadcasting_strategy_t::scalar,
                broadcasting_strategy_t::per_oc,
                broadcasting_strategy_t::per_oc_spatial,
                broadcasting_strategy_t::per_mb_spatial,
                broadcasting_strategy_t::per_mb_w,
                broadcasting_strategy_t::per_w,
                broadcasting_strategy_t::no_broadcast};
        const bool is_binary_po_per_oc
                = binary_injector_utils::bcast_strategy_present(
                        binary_injector_utils::extract_bcast_strategies(
                                post_ops.entry_, dst_md()),
                        broadcasting_strategy_t::per_oc);
        return cpu::inner_product_utils::post_ops_ok(
                       post_ops, dst_md(), enabled_bcast_strategy)
                && IMPLICATION(is_binary_po_per_oc,
                        gemm_based::check_gemm_binary_per_oc_compatible_formats(
                                *this));
    };

    // check basic attributes
    if (!check_attr_scales()) return status::unimplemented;

    // set state: copy the attributes the post-processing kernel will see
    CHECK(params_.pp_attr_.copy_from(*attr()));
    // GEMM applies the scales itself only when they are per-tensor and
    // there is no bias; otherwise scaling is deferred to the
    // post-processing kernel.
    params_.gemm_applies_output_scales_
            = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 && !with_bias();

    // Scales already folded into GEMM's alpha must not be re-applied by
    // the post-processing kernel.
    if (params_.gemm_applies_output_scales_) {
        params_.pp_attr_.scales_.reset(DNNL_ARG_SRC);
        params_.pp_attr_.scales_.reset(DNNL_ARG_WEIGHTS);
    }

    // check post-ops
    if (!check_attr_post_ops()) return status::unimplemented;
    const bool sum_po_via_gemm_beta
            = should_gemm_execute_sum_po(params_, dst_type);
    // set state: dst can serve directly as the f32 accumulator when it is
    // f32 and any sum post-op is already handled by GEMM's beta
    params_.dst_is_acc_ = dst_type == data_type::f32
            && IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1,
                    sum_po_via_gemm_beta);

    if (sum_po_via_gemm_beta) {
        // set state: fold the sum post-op scale into GEMM's beta
        // (C = alpha * A * B + beta * C)
        const auto &po = params_.pp_attr_.post_ops_;
        static constexpr int sum_idx = 0;
        params_.gemm_beta_ = po.entry_[sum_idx].sum.scale;
    }

    // set state: a post-processing kernel is needed whenever dst is not
    // the accumulator (data-type conversion) or bias/attributes remain
    params_.has_pp_kernel_ = !params_.dst_is_acc_ || with_bias()
            || !params_.pp_attr_.has_default_values();

    return status::success;
}
167 | |
168 | template <impl::data_type_t dst_type> |
169 | bool gemm_bf16_matmul_t<dst_type>::should_skip_sum_po() const noexcept { |
170 | return should_gemm_execute_sum_po(pd()->params(), dst_type); |
171 | } |
172 | |
// Reference execution path: partitions the (batch x M x N) problem into
// chunks, runs gemm_bf16bf16f32 on each chunk into an f32 accumulator,
// then applies the optional post-processing kernel (bias, scales,
// post-ops, dst data-type conversion).
template <impl::data_type_t dst_type>
status_t gemm_bf16_matmul_t<dst_type>::execute_ref(
        const exec_ctx_t &ctx) const {
    using namespace binary_injector_utils;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const weights_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto &po = this->pd()->attr()->post_ops_;
    const auto post_ops_binary_rhs_arg_vec = prepare_binary_args(po, ctx);

    const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
    const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
    const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());

    const int ndims = pd()->ndims();

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Combine src and weights scales into a single array sized by the
    // last dst dimension (N).
    auto scratchpad = ctx.get_scratchpad_grantor();
    const float *scales = precompute_scales(scratchpad, src_scales, wei_scales,
            dst_d.dims()[ndims - 1], pd()->attr());

    // Nothing to compute when any tensor has a zero dimension.
    if (src_d.has_zero_dim() || weights_d.has_zero_dim()
            || dst_d.has_zero_dim())
        return status::success;

    // Resolve (possibly runtime) problem geometry and GEMM layout params.
    matmul_helper_t helper(src_d, weights_d, dst_d);
    const int batch_ndims = ndims - 2;
    dim_t M = helper.M();
    const dim_t N = helper.N();
    const dim_t K = helper.K();
    const dim_t batch = helper.batch();
    // Sub-batch sizes used to recover per-dimension indices from the
    // flattened batch index cur_b; only read when helper.ndims() > 3.
    const dim_t batch_without_dim0
            = helper.ndims() > 3 ? batch / dst_d.dims()[0] : 0;
    const dim_t batch_without_dim01
            = helper.ndims() > 4 ? batch_without_dim0 / dst_d.dims()[1] : 1;
    const char transA = helper.transA();
    const char transB = helper.transB();
    const dim_t lda = helper.lda();
    const dim_t ldb = helper.ldb();
    const dim_t ldc = helper.ldc();
    const int nthr = pd()->nthr_;

    const gemm_based::params_t &params = pd()->params();
    const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
            ? helper.can_fuse_src_batch_dims()
            : params.can_fuse_src_batch_dims_;
    const dim_t acc_stride = gemm_based::get_scratchpad_size(
            batch, M, N, can_fuse_src_batch_dims, nthr);
    // The f32 accumulator is either dst itself (dst_is_acc) or a
    // scratchpad buffer booked at pd creation time.
    bool dst_is_acc = params.dst_is_acc_;
    acc_data_t *acc = dst_is_acc
            ? (acc_data_t *)dst
            : ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    memory_tracking::names::key_matmul_dst_in_acc_dt);
    // case: dynamic sizes -- the scratchpad could not be booked up front,
    // so allocate the accumulator here. NOTE: the two-argument malloc is
    // oneDNN's aligned allocator (common/utils.hpp), not libc malloc.
    bool need_free_acc = false;
    if (acc == nullptr) {
        acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride
                        * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
                64);
        if (acc == nullptr) return status::out_of_memory;
        need_free_acc = true;
    }

    // alpha carries the combined scales when GEMM applies them; beta
    // carries a sum post-op scale fused into GEMM (see pd configuration).
    const float alpha = params.get_gemm_alpha(scales);
    const float beta = params.gemm_beta_;
    const dim_t acc_ldc = dst_is_acc ? ldc : N;
    // 1 when weights scales are per-N (stride into `scales`), else 0.
    const int scale_idx_mult
            = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_
            == (1 << (ndims - 1));

    std::atomic<status_t> st(status::success);
    // use parallel over batch when binary po with channel bcast
    // (except batch == 1)
    bool is_binary_po_per_oc;
    bool is_binary_po_per_oc_sp;
    bool is_binary_po_channel_bcast;
    std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
            is_binary_po_channel_bcast)
            = bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
                    broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial);
    // if batched, parallel over batch for per_mb_sp and per_oc binary
    // post-op broadcast
    const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
            && IMPLICATION(
                    is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
    const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
    if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
        // Chunked path: each thread owns a contiguous span of the
        // flattened batch*M*N iteration space.
        const int src_mask
                = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
        const int wei_mask
                = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
        const size_t bia_dt_size = !pd()->with_bias()
                ? 0
                : types::data_type_size(pd()->weights_md(1)->data_type);
        const size_t work_amount = (size_t)batch * M * N;
        const size_t work_per_batch = (size_t)M * N;

        // NOTE: inside lambda, type cast variables captured by reference using
        // either c-like "(type)var" or functional "type(var)" notation in order
        // to avoid gcc bug with c++14 standard. Otherwise, capture by value.
        parallel(nthr, [=, &st](int ithr, int nthr) {
            size_t t_work_start {0}, t_work_end {0};
            balance211(work_amount, nthr, ithr, t_work_start, t_work_end);

            dim_t cur_b {0}, cur_m {0}, cur_n {0};
            dims_t s_dims_idx, w_dims_idx, d_dims_idx;
            size_t i_work = t_work_start;
            // When acc is a scratchpad, each thread reuses its own slice;
            // when acc aliases dst, per-chunk offsets index it directly.
            const bool reuse_acc = acc != (acc_data_t *)dst;
            acc_data_t *curr_acc
                    = reuse_acc ? acc + ithr * acc_stride : nullptr;

            while (i_work < t_work_end) {
                // Decompose the flat work index into (batch, m, n) and the
                // per-tensor logical indices (with broadcast masks applied).
                utils::nd_iterator_init(
                        i_work, cur_b, batch, cur_m, M, cur_n, N);

                utils::l_dims_by_l_offset(
                        d_dims_idx, i_work, dst_d.dims(), ndims);
                utils::copy_dims_with_mask(
                        s_dims_idx, d_dims_idx, batch_ndims, src_mask);
                s_dims_idx[ndims - 2] = cur_m;
                s_dims_idx[ndims - 1] = 0; // k idx is always 0

                utils::copy_dims_with_mask(
                        w_dims_idx, d_dims_idx, batch_ndims, wei_mask);
                w_dims_idx[ndims - 2] = 0; // k idx is always 0
                w_dims_idx[ndims - 1] = cur_n;
                const src_data_t *curr_src = src + src_d.off_v(s_dims_idx);
                const weights_data_t *curr_weights
                        = weights + weights_d.off_v(w_dims_idx);
                const dim_t dst_off = dst_d.off_v(d_dims_idx);
                dst_data_t *curr_dst = dst + dst_off;
                if (!reuse_acc) curr_acc = acc + dst_off;
                dim_t gemm_M {0}, gemm_N {0};

                // Pick the largest contiguous chunk still available: a
                // whole MxN matrix, a span of full rows, or part of a row.
                size_t matrix_offset;
                const size_t rem_work = t_work_end - i_work;
                if (rem_work >= work_per_batch && cur_m == 0 && cur_n == 0) {
                    // parallel over batch
                    gemm_M = M;
                    gemm_N = N;
                    matrix_offset = 0;
                } else if (rem_work >= (size_t)N && cur_n == 0) {
                    // parallel over M
                    gemm_M = nstl::min(
                            (size_t)(M - cur_m), (size_t)(rem_work / N));
                    gemm_N = N;
                    matrix_offset = cur_n + cur_m * N;
                } else {
                    // parallel over N
                    gemm_M = 1;
                    gemm_N = nstl::min((size_t)(N - cur_n), rem_work);
                    matrix_offset = cur_n + cur_m * N;
                }

                // Column-major GEMM convention: swap A/B and M/N so the
                // row-major problem maps onto gemm_bf16bf16f32.
                status_t st_thr = gemm_bf16bf16f32(&transB, &transA, &gemm_N,
                        &gemm_M, &K, &alpha, curr_weights, &ldb, curr_src, &lda,
                        &beta, curr_acc, &acc_ldc);
                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (params.has_pp_kernel_) {
                    const float *pp_scales
                            = params.get_post_processing_scales(scales);
                    const size_t dst_logical_off = i_work;
                    const size_t dim1_off = helper.ndims() > 3
                            ? ((cur_b % batch_without_dim0)
                                    / batch_without_dim01)
                            : cur_m;
                    // offset for case with post-op broadcast_channel
                    const size_t matrix_per_first_batch_off = helper.ndims() > 3
                            ? M * N * (cur_b / batch_without_dim0)
                                    + matrix_offset
                            : 0;
                    const ptrdiff_t oc_off = i_work % N;
                    (*pp_kernel_)(curr_dst, curr_acc,
                            bias + oc_off * bia_dt_size,
                            pp_scales + oc_off * scale_idx_mult, dst_scales[0],
                            0, dst_logical_off, dim1_off, gemm_M * gemm_N,
                            static_cast<size_t>(N), ldc, nullptr,
                            post_ops_binary_rhs_arg_vec.data(), dst,
                            matrix_per_first_batch_off, ctx, *pd()->dst_md());
                }
                i_work += gemm_M * gemm_N;
            }
        });
    } else {
        // collapse batch into M, if weights batch dimensions are broadcasted.
        M = M * batch;

        st = gemm_bf16bf16f32(&transB, &transA, &N, &M, &K, &alpha, weights,
                &ldb, src, &lda, &beta, acc, &acc_ldc);

        if (st == status::success && params.has_pp_kernel_) {
            // Post-process the whole (fused) matrix, parallelized over
            // elements unless the kernel demands sequential execution.
            const bool force_sequential = pp_kernel_->sequential_kernel();
            const float *pp_scales = params.get_post_processing_scales(scales);

            parallel(force_sequential ? 1 : nthr, [&](int ithr, int nthr) {
                size_t start {}, end {};
                balance211((size_t)(M * N), nthr, ithr, start, end);
                const size_t dst_logical_off = start;
                const size_t dim1_off = start % N;
                (*pp_kernel_)(dst, acc, bias, pp_scales, dst_scales[0], start,
                        dst_logical_off, dim1_off, end, (size_t)N, ldc, nullptr,
                        post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx,
                        *pd()->dst_md());
            });
        }
    }

    if (need_free_acc) free(acc);

    return st;
}
394 | |
using namespace data_type;
// Explicit instantiations for the supported destination data types.
template struct gemm_bf16_matmul_t<data_type::f32>;
template struct gemm_bf16_matmul_t<data_type::bf16>;
398 | |
399 | } // namespace matmul |
400 | } // namespace cpu |
401 | } // namespace impl |
402 | } // namespace dnnl |
403 | |