/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include <assert.h>
#include <float.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/scale_utils.hpp"

#include "cpu/gemm/gemm.hpp"

#include "cpu/binary_injector_utils.hpp"
#include "cpu/matmul/gemm_f32_matmul.hpp"
#include "cpu/matmul/matmul_utils.hpp"
#include "cpu/scale_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace matmul {

using namespace data_type;

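// Primitive descriptor initialization: accept the problem only if the data
// types match this f32 GEMM-based implementation, the bias (if any) is an
// f32 1xN row, the attributes are limited to runtime scales, post-ops and
// sum data type, and the memory formats are compatible with a plain GEMM
// call.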
status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) {
    auto check_bias = [&]() -> bool {
        return !with_bias()
                || (weights_md(1)->data_type == f32 && is_bias_1xN());
    };

    bool ok = !has_zero_dim_memory() && src_md()->data_type == src_type
            && weights_md()->data_type == weights_type
            && desc()->accum_data_type == acc_type
            && dst_md()->data_type == dst_type && check_bias()
            && attr()->has_default_values(
                    primitive_attr_t::skip_mask_t::scales_runtime
                            | primitive_attr_t::skip_mask_t::post_ops
                            | primitive_attr_t::skip_mask_t::sum_dt,
                    dst_type)
            && attr()->post_ops_.check_sum_consistent_dt(dst_type)
            && set_default_formats()
            && attr_.set_default_formats(dst_md(0)) == status::success
            && gemm_based::check_gemm_compatible_formats(*this);

    if (!ok) return status::unimplemented;

    if (!has_runtime_dims_or_strides())
        params_.can_fuse_src_batch_dims_
                = matmul_helper_t(src_md(), weights_md(), dst_md())
                          .can_fuse_src_batch_dims();

    CHECK(check_and_configure_attributes());

    nthr_ = dnnl_get_max_threads();
    gemm_based::book_acc_scratchpad(*this, params_, sizeof(acc_data_t), nthr_);
    auto scratchpad = scratchpad_registry().registrar();
    book_precomputed_scales(scratchpad, attr()->scales_, N());

    return status::success;
}

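// A sum post-op can be folded into the GEMM beta parameter only when it is
// the first post-op, GEMM already applies the output scales via alpha, the
// sum zero-point is 0, and the sum data type is either undefined or matches
// the destination data type.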
static bool should_gemm_execute_sum_po(
        const gemm_based::params_t &params, data_type_t dst_dt) noexcept {
    const auto &po = params.pp_attr_.post_ops_;
    static constexpr int sum_idx = 0;
    return po.len() > 0 && po.contain(primitive_kind::sum, sum_idx)
            && params.gemm_applies_output_scales_
            && po.entry_[sum_idx].sum.zero_point == 0
            && utils::one_of(
                    po.entry_[sum_idx].sum.dt, dst_dt, data_type::undef);
}

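// Validates the scales and post-op attributes and derives the GEMM-related
// state: whether the scales can be folded into the GEMM alpha, whether dst
// can be used directly as the accumulator, the GEMM beta value, and whether
// a post-processing kernel is required after the GEMM call.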
status_t gemm_f32_matmul_t::pd_t::check_and_configure_attributes() {
    auto check_attr_scales = [&]() -> bool {
        using namespace data_type;
        const std::vector<int> supported_args
                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
        bool ok = attr()->scales_.has_default_values(supported_args);
        for (int arg : supported_args) {
            const auto &mask = attr()->scales_.get(arg).mask_;
            if (arg == DNNL_ARG_WEIGHTS)
                ok = ok && (mask == 0 || mask == (1 << (dst_md()->ndims - 1)));
            else
                ok = ok && (mask == 0);
        }

        if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values()
                && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()
                && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) {
            // This case requires a scratchpad whose size is unknown at
            // primitive creation time.
            if (N() == DNNL_RUNTIME_DIM_VAL) return false;
        }
        return ok;
    };

    auto check_attr_post_ops = [&]() -> bool {
        using namespace primitive_kind;
        const auto &post_ops = attr()->post_ops_;
        static const bcast_set_t enabled_bcast_strategy {
                broadcasting_strategy_t::scalar,
                broadcasting_strategy_t::per_oc,
                broadcasting_strategy_t::per_oc_spatial,
                broadcasting_strategy_t::per_mb_spatial,
                broadcasting_strategy_t::per_mb_w,
                broadcasting_strategy_t::per_w,
                broadcasting_strategy_t::no_broadcast};
        const bool is_binary_po_per_oc
                = binary_injector_utils::bcast_strategy_present(
                        binary_injector_utils::extract_bcast_strategies(
                                post_ops.entry_, dst_md()),
                        broadcasting_strategy_t::per_oc);
        return cpu::inner_product_utils::post_ops_ok(
                       post_ops, dst_md(), enabled_bcast_strategy)
                && IMPLICATION(is_binary_po_per_oc,
                        gemm_based::check_gemm_binary_per_oc_compatible_formats(
                                *this));
    };

    // check basic attributes
    if (!check_attr_scales()) return status::unimplemented;

    // set state
    CHECK(params_.pp_attr_.copy_from(*attr()));
    params_.gemm_applies_output_scales_
            = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 && !with_bias();
    if (params_.gemm_applies_output_scales_) {
        params_.pp_attr_.scales_.reset(DNNL_ARG_SRC);
        params_.pp_attr_.scales_.reset(DNNL_ARG_WEIGHTS);
    }

    // check post-ops
    if (!check_attr_post_ops()) return status::unimplemented;
    const bool sum_po_via_gemm_beta
            = should_gemm_execute_sum_po(params_, dst_md()->data_type);
    // set state
    params_.dst_is_acc_
            = IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1,
                    sum_po_via_gemm_beta);

    if (sum_po_via_gemm_beta) {
        // set state
        const auto &po = params_.pp_attr_.post_ops_;
        static constexpr int sum_idx = 0;
        params_.gemm_beta_ = po.entry_[sum_idx].sum.scale;
    }

    // set state
    params_.has_pp_kernel_ = !params_.dst_is_acc_ || with_bias()
            || !params_.pp_attr_.has_default_values();

    return status::success;
}

bool gemm_f32_matmul_t::should_skip_sum_po(data_type_t dst_dt) const noexcept {
    return should_gemm_execute_sum_po(pd()->params(), dst_dt);
}

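// Reference execution path. Two strategies are used below: a batched path
// that splits the flattened (batch x M x N) work across threads and issues
// one GEMM call per chunk, and a fused path that collapses the batch into M
// and performs a single GEMM over the whole problem.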
status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
    using namespace binary_injector_utils;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const weights_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto &po = this->pd()->attr()->post_ops_;
    const auto post_ops_binary_rhs_arg_vec = prepare_binary_args(po, ctx);

    const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
    const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
    const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());

    const int ndims = pd()->ndims();

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

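    // Note: precompute_scales() is expected to combine the src and weights
    // scales into a single buffer (per-N when the weights scales are
    // per-OC), using the scratchpad space booked at pd creation time.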
    auto scratchpad = ctx.get_scratchpad_grantor();
    const float *scales = precompute_scales(scratchpad, src_scales, wei_scales,
            dst_d.dims()[ndims - 1], pd()->attr());

    if (src_d.has_zero_dim() || weights_d.has_zero_dim()
            || dst_d.has_zero_dim())
        return status::success;

    matmul_helper_t helper(src_d, weights_d, dst_d);
    const int batch_ndims = ndims - 2;
    dim_t M = helper.M();
    const dim_t N = helper.N();
    const dim_t K = helper.K();
    const dim_t batch = helper.batch();
    const dim_t batch_without_dim0
            = helper.ndims() > 3 ? batch / dst_d.dims()[0] : 0;
    const dim_t batch_without_dim01
            = helper.ndims() > 4 ? batch_without_dim0 / dst_d.dims()[1] : 1;
    const char transA = helper.transA();
    const char transB = helper.transB();
    const dim_t lda = helper.lda();
    const dim_t ldb = helper.ldb();
    const dim_t ldc = helper.ldc();
    const int nthr = pd()->nthr_;

    const gemm_based::params_t &params = pd()->params();
    const float alpha = params.get_gemm_alpha(scales);
    const float beta = params.gemm_beta_;
    const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
            ? helper.can_fuse_src_batch_dims()
            : params.can_fuse_src_batch_dims_;
    const dim_t acc_stride = gemm_based::get_scratchpad_size(
            batch, M, N, can_fuse_src_batch_dims, nthr);
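    // Pick the accumulator buffer: write directly into dst when it can serve
    // as the accumulator (any sum post-op is handled by the GEMM beta),
    // otherwise use the scratchpad booked at creation time.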
    bool dst_is_acc = params.dst_is_acc_;
    acc_data_t *acc = dst_is_acc
            ? (acc_data_t *)dst
            : ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    memory_tracking::names::key_matmul_dst_in_acc_dt);
    // case: dynamic sizes, the scratchpad could not be sized at creation time
    bool need_free_acc = false;
    if (acc == nullptr) {
        acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride
                        * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
                64);
        if (acc == nullptr) return status::out_of_memory;
        need_free_acc = true;
    }

    const dim_t acc_ldc = dst_is_acc ? ldc : N;
    const int scale_idx_mult
            = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_
            == (1 << (ndims - 1));

    std::atomic<status_t> st(status::success);
    // Use the parallel-over-batch path when a binary post-op broadcasts over
    // channels (unless batch == 1).
    bool is_binary_po_per_oc = false;
    bool is_binary_po_per_oc_sp = false;
    bool is_binary_po_channel_bcast = false;
    std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
            is_binary_po_channel_bcast)
            = bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
                    broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial);
    // If batched, parallelize over batch for per_mb_spatial and per_oc binary
    // post-op broadcasts.
    const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
            && IMPLICATION(
                    is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
    const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
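    // Take the batched path either when the post-ops cannot be applied with
    // the batch dimensions fused into M, or when there is more than one batch
    // and the src batch dimensions cannot be fused anyway.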
    if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
        const int src_mask
                = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
        const int wei_mask
                = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
        const size_t bia_dt_size = !pd()->with_bias()
                ? 0
                : types::data_type_size(pd()->weights_md(1)->data_type);
        const size_t work_amount = (size_t)batch * M * N;
        const size_t work_per_batch = (size_t)M * N;
        parallel(nthr, [&](int ithr, int nthr) {
            size_t t_work_start {0}, t_work_end {0};
            balance211(work_amount, nthr, ithr, t_work_start, t_work_end);

            dim_t cur_b {0}, cur_m {0}, cur_n {0};
            dims_t s_dims_idx, w_dims_idx, d_dims_idx;
            size_t i_work = t_work_start;
            const bool reuse_acc = acc != (acc_data_t *)dst;
            acc_data_t *curr_acc
                    = reuse_acc ? acc + ithr * acc_stride : nullptr;

            while (i_work < t_work_end) {
                utils::nd_iterator_init(
                        i_work, cur_b, batch, cur_m, M, cur_n, N);

                utils::l_dims_by_l_offset(
                        d_dims_idx, i_work, dst_d.dims(), ndims);

                utils::copy_dims_with_mask(
                        s_dims_idx, d_dims_idx, batch_ndims, src_mask);
                s_dims_idx[ndims - 2] = cur_m;
                s_dims_idx[ndims - 1] = 0; // k idx is always 0

                utils::copy_dims_with_mask(
                        w_dims_idx, d_dims_idx, batch_ndims, wei_mask);
                w_dims_idx[ndims - 2] = 0; // k idx is always 0
                w_dims_idx[ndims - 1] = cur_n;

                const src_data_t *curr_src = src + src_d.off_v(s_dims_idx);
                const weights_data_t *curr_weights
                        = weights + weights_d.off_v(w_dims_idx);
                const dim_t dst_off = dst_d.off_v(d_dims_idx);
                dst_data_t *curr_dst = dst + dst_off;
                if (!reuse_acc) curr_acc = acc + dst_off;
                dim_t gemm_M {0}, gemm_N {0};

                size_t matrix_offset;
                const size_t rem_work = t_work_end - i_work;
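                // Carve the largest contiguous chunk the remaining work
                // allows: a whole MxN matrix (parallel over batch), a block
                // of full rows (parallel over M), or part of a single row
                // (parallel over N).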
                if (rem_work >= work_per_batch && cur_m == 0 && cur_n == 0) {
                    // parallel over batch
                    gemm_M = M;
                    gemm_N = N;
                    matrix_offset = 0;
                } else if (rem_work >= (size_t)N && cur_n == 0) {
                    // parallel over M
                    gemm_M = nstl::min(
                            (size_t)(M - cur_m), (size_t)(rem_work / N));
                    gemm_N = N;
                    matrix_offset = cur_n + cur_m * N;
                } else {
                    // parallel over N
                    gemm_M = 1;
                    gemm_N = nstl::min((size_t)(N - cur_n), rem_work);
                    matrix_offset = cur_n + cur_m * N;
                }

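                // Note: extended_sgemm() follows the column-major BLAS
                // convention, so the row-major C = A * B is computed by
                // passing B and A swapped (i.e. computing C^T = B^T * A^T).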
                status_t st_thr = extended_sgemm(&transB, &transA, &gemm_N,
                        &gemm_M, &K, &alpha, curr_weights, &ldb, curr_src, &lda,
                        &beta, curr_acc, &acc_ldc, nullptr, false);
                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

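                // The post-processing kernel adds the bias, applies the
                // remaining src/weights and dst scales, and executes the
                // post-ops on the accumulated chunk.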
                if (params.has_pp_kernel_) {
                    const float *pp_scales
                            = params.get_post_processing_scales(scales);
                    const size_t dst_logical_off = i_work;
                    const size_t dim1_off = helper.ndims() > 3
                            ? ((cur_b % batch_without_dim0)
                                    / batch_without_dim01)
                            : cur_m;

                    // offset for the case with a channel-broadcast post-op
                    const size_t matrix_per_first_batch_off = helper.ndims() > 3
                            ? M * N * (cur_b / batch_without_dim0)
                                    + matrix_offset
                            : 0;
                    const ptrdiff_t oc_off = i_work % N;
                    (*pp_kernel_)(curr_dst, curr_acc,
                            bias + oc_off * bia_dt_size,
                            pp_scales + oc_off * scale_idx_mult, dst_scales[0],
                            0, dst_logical_off, dim1_off, gemm_M * gemm_N,
                            static_cast<size_t>(N), ldc, nullptr,
                            post_ops_binary_rhs_arg_vec.data(), dst,
                            matrix_per_first_batch_off, ctx, *pd()->dst_md());
                }
                i_work += gemm_M * gemm_N;
            }
        });
    } else {
        // Collapse batch into M when the weights batch dimensions are
        // broadcast, so a single GEMM call covers the whole batch.
        M = batch * M;

        st = extended_sgemm(&transB, &transA, &N, &M, &K, &alpha, weights, &ldb,
                src, &lda, &beta, acc, &acc_ldc, nullptr, false);

        if (st == status::success && params.has_pp_kernel_) {
            const bool force_sequential = pp_kernel_->sequential_kernel();
            const float *pp_scales = params.get_post_processing_scales(scales);
            parallel(force_sequential ? 1 : nthr, [&](int ithr, int nthr) {
                size_t start {}, end {};
                balance211((size_t)(M * N), nthr, ithr, start, end);
                const size_t dst_logical_off = start;
                const size_t dst_start_row_idx = start % N;
                (*pp_kernel_)(dst, acc, bias, pp_scales, dst_scales[0], start,
                        dst_logical_off, dst_start_row_idx, end, (size_t)N, ldc,
                        nullptr, post_ops_binary_rhs_arg_vec.data(), dst, 0,
                        ctx, *pd()->dst_md());
            });
        }
    }

    if (need_free_acc) free(acc);

    return st;
}

} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace dnnl