1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <atomic> |
18 | |
19 | #include <assert.h> |
20 | #include <float.h> |
21 | #include <math.h> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/dnnl_thread.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/cpu_primitive.hpp" |
29 | #include "cpu/scale_utils.hpp" |
30 | |
31 | #include "cpu/gemm/gemm.hpp" |
32 | |
33 | #include "cpu/binary_injector_utils.hpp" |
34 | #include "cpu/matmul/gemm_f32_matmul.hpp" |
35 | #include "cpu/matmul/matmul_utils.hpp" |
36 | #include "cpu/scale_utils.hpp" |
37 | |
38 | namespace dnnl { |
39 | namespace impl { |
40 | namespace cpu { |
41 | namespace matmul { |
42 | |
43 | using namespace data_type; |
44 | |
// Initializes the primitive descriptor: validates data types, bias,
// attributes and memory formats for the f32 GEMM-based matmul, then
// derives execution parameters and books scratchpad space.
// Returns status::unimplemented when any requirement is not met.
status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) {
    // Bias, when present, must be f32 with a 1xN (per-output-channel
    // broadcastable) layout.
    auto check_bias = [&]() -> bool {
        return !with_bias()
                || (weights_md(1)->data_type == f32 && is_bias_1xN());
    };

    // Reject zero-dim memory and any data-type / attribute / format
    // combination this GEMM path cannot handle. Only runtime scales,
    // post-ops and sum data type are allowed in the attributes.
    bool ok = !has_zero_dim_memory() && src_md()->data_type == src_type
            && weights_md()->data_type == weights_type
            && desc()->accum_data_type == acc_type
            && dst_md()->data_type == dst_type && check_bias()
            && attr()->has_default_values(
                    primitive_attr_t::skip_mask_t::scales_runtime
                            | primitive_attr_t::skip_mask_t::post_ops
                            | primitive_attr_t::skip_mask_t::sum_dt,
                    dst_type)
            && attr()->post_ops_.check_sum_consistent_dt(dst_type)
            && set_default_formats()
            && attr_.set_default_formats(dst_md(0)) == status::success
            && gemm_based::check_gemm_compatible_formats(*this);

    if (!ok) return status::unimplemented;

    // With static shapes, decide once here whether the source batch
    // dimensions can be collapsed into M; with runtime dims/strides the
    // decision is deferred to execute_ref().
    if (!has_runtime_dims_or_strides())
        params_.can_fuse_src_batch_dims_
                = matmul_helper_t(src_md(), weights_md(), dst_md())
                          .can_fuse_src_batch_dims();

    CHECK(check_and_configure_attributes());

    // Book scratchpad for the f32 accumulator and for the precomputed
    // (src * wei) scales buffer.
    nthr_ = dnnl_get_max_threads();
    gemm_based::book_acc_scratchpad(*this, params_, sizeof(acc_data_t), nthr_);
    auto scratchpad = scratchpad_registry().registrar();
    book_precomputed_scales(scratchpad, attr()->scales_, N());

    return status::success;
}
81 | |
82 | static bool should_gemm_execute_sum_po( |
83 | const gemm_based::params_t ¶ms, data_type_t dst_dt) noexcept { |
84 | const auto &po = params.pp_attr_.post_ops_; |
85 | static constexpr int sum_idx = 0; |
86 | return po.len() > 0 && po.contain(primitive_kind::sum, sum_idx) |
87 | && params.gemm_applies_output_scales_ |
88 | && po.entry_[sum_idx].sum.zero_point == 0 |
89 | && utils::one_of( |
90 | po.entry_[sum_idx].sum.dt, dst_dt, data_type::undef); |
91 | } |
92 | |
// Validates scale and post-op attributes and derives GEMM execution
// state in params_: whether GEMM itself applies the output scales,
// whether dst can double as the accumulator, the GEMM beta used for a
// fused sum post-op, and whether a post-processing kernel is needed.
status_t gemm_f32_matmul_t::pd_t::check_and_configure_attributes() {
    // Scales are accepted for SRC/WEIGHTS/DST; SRC and DST must use a
    // common scale (mask 0), WEIGHTS may additionally be per-N.
    auto check_attr_scales = [&]() -> bool {
        using namespace data_type;
        const std::vector<int> supported_args
                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
        bool ok = attr()->scales_.has_default_values(supported_args);
        for (int arg : supported_args) {
            const auto &mask = attr()->scales_.get(arg).mask_;
            if (arg == DNNL_ARG_WEIGHTS)
                ok = ok && (mask == 0 || mask == (1 << (dst_md()->ndims - 1)));
            else
                ok = ok && (mask == 0);
        }

        // Combining src scales with per-N weights scales needs a
        // precomputed-scales buffer of size N.
        if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values()
                && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()
                && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) {
            // This case requires scratchpad with unknown size
            if (N() == DNNL_RUNTIME_DIM_VAL) return false;
        }
        return ok;
    };

    // Post-ops are accepted when the inner-product post-op machinery
    // supports them; per-OC binary broadcast additionally requires
    // formats compatible with the GEMM layout.
    auto check_attr_post_ops = [&]() -> bool {
        using namespace primitive_kind;
        const auto &post_ops = attr()->post_ops_;
        static const bcast_set_t enabled_bcast_strategy {
                broadcasting_strategy_t::scalar,
                broadcasting_strategy_t::per_oc,
                broadcasting_strategy_t::per_oc_spatial,
                broadcasting_strategy_t::per_mb_spatial,
                broadcasting_strategy_t::per_mb_w,
                broadcasting_strategy_t::per_w,
                broadcasting_strategy_t::no_broadcast};
        const bool is_binary_po_per_oc
                = binary_injector_utils::bcast_strategy_present(
                        binary_injector_utils::extract_bcast_strategies(
                                post_ops.entry_, dst_md()),
                        broadcasting_strategy_t::per_oc);
        return cpu::inner_product_utils::post_ops_ok(
                       post_ops, dst_md(), enabled_bcast_strategy)
                && IMPLICATION(is_binary_po_per_oc,
                        gemm_based::check_gemm_binary_per_oc_compatible_formats(
                                *this));
    };

    // check basic attributes
    if (!check_attr_scales()) return status::unimplemented;

    // set state: GEMM can apply the scales itself (via alpha) only when
    // the weights scale is common and no bias has to be added first.
    CHECK(params_.pp_attr_.copy_from(*attr()));
    params_.gemm_applies_output_scales_
            = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 && !with_bias();
    if (params_.gemm_applies_output_scales_) {
        // Scales are consumed by GEMM, so drop them from the
        // post-processing attributes to avoid applying them twice.
        params_.pp_attr_.scales_.reset(DNNL_ARG_SRC);
        params_.pp_attr_.scales_.reset(DNNL_ARG_WEIGHTS);
    }

    // check post-ops
    if (!check_attr_post_ops()) return status::unimplemented;
    const bool sum_po_via_gemm_beta
            = should_gemm_execute_sum_po(params_, dst_md()->data_type);
    // set state: dst may act as the accumulator unless a sum post-op
    // must be applied by the pp kernel (which reads dst after GEMM).
    params_.dst_is_acc_
            = IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1,
                    sum_po_via_gemm_beta);

    if (sum_po_via_gemm_beta) {
        // set state: fuse the sum post-op into GEMM as beta.
        const auto &po = params_.pp_attr_.post_ops_;
        static constexpr int sum_idx = 0;
        params_.gemm_beta_ = po.entry_[sum_idx].sum.scale;
    }

    // set state: a pp kernel is required to copy acc -> dst, add bias,
    // or apply any remaining (non-default) attributes.
    params_.has_pp_kernel_ = !params_.dst_is_acc_ || with_bias()
            || !params_.pp_attr_.has_default_values();

    return status::success;
}
173 | |
174 | bool gemm_f32_matmul_t::should_skip_sum_po(data_type_t dst_dt) const noexcept { |
175 | return should_gemm_execute_sum_po(pd()->params(), dst_dt); |
176 | } |
177 | |
// Executes the matmul as one or more SGEMM calls plus an optional
// post-processing kernel (bias, scales, post-ops, acc -> dst copy).
// Two strategies are used:
//  - parallel over (batch, M, N) work chunks, each chunk running its
//    own sequential SGEMM, when batch dims cannot be fused or the
//    binary post-op broadcast requires it;
//  - a single SGEMM with batch collapsed into M otherwise.
status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
    using namespace binary_injector_utils;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const weights_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto &po = this->pd()->attr()->post_ops_;
    const auto post_ops_binary_rhs_arg_vec = prepare_binary_args(po, ctx);

    const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
    const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
    const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());

    const int ndims = pd()->ndims();

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Fold src and weights scales into a single buffer (size N for
    // per-N weights scales, otherwise a single value).
    auto scratchpad = ctx.get_scratchpad_grantor();
    const float *scales = precompute_scales(scratchpad, src_scales, wei_scales,
            dst_d.dims()[ndims - 1], pd()->attr());

    // Nothing to compute when any tensor has a zero dimension.
    if (src_d.has_zero_dim() || weights_d.has_zero_dim()
            || dst_d.has_zero_dim())
        return status::success;

    matmul_helper_t helper(src_d, weights_d, dst_d);
    const int batch_ndims = ndims - 2;
    dim_t M = helper.M();
    const dim_t N = helper.N();
    const dim_t K = helper.K();
    const dim_t batch = helper.batch();
    // Sub-batch sizes used to locate a flattened batch index within the
    // outer dst batch dims (only meaningful for ndims > 3 / > 4).
    const dim_t batch_without_dim0
            = helper.ndims() > 3 ? batch / dst_d.dims()[0] : 0;
    const dim_t batch_without_dim01
            = helper.ndims() > 4 ? batch_without_dim0 / dst_d.dims()[1] : 1;
    const char transA = helper.transA();
    const char transB = helper.transB();
    const dim_t lda = helper.lda();
    const dim_t ldb = helper.ldb();
    const dim_t ldc = helper.ldc();
    const int nthr = pd()->nthr_;

    const gemm_based::params_t &params = pd()->params();
    // alpha carries the output scales when GEMM applies them; beta
    // carries a fused sum post-op scale (see pd_t init logic).
    const float alpha = params.get_gemm_alpha(scales);
    const float beta = params.gemm_beta_;
    const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
            ? helper.can_fuse_src_batch_dims()
            : params.can_fuse_src_batch_dims_;
    const dim_t acc_stride = gemm_based::get_scratchpad_size(
            batch, M, N, can_fuse_src_batch_dims, nthr);
    bool dst_is_acc = params.dst_is_acc_;
    // Accumulate directly into dst when possible, otherwise use the
    // booked scratchpad buffer.
    acc_data_t *acc = dst_is_acc
            ? (acc_data_t *)dst
            : ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    memory_tracking::names::key_matmul_dst_in_acc_dt);
    // case: dynamic sizes — scratchpad was not booked at pd time, so
    // allocate a 64-byte-aligned buffer here (impl-local malloc
    // overload taking an alignment argument).
    bool need_free_acc = false;
    if (acc == nullptr) {
        acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride
                        * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
                64);
        if (acc == nullptr) return status::out_of_memory;
        need_free_acc = true;
    }

    const dim_t acc_ldc = dst_is_acc ? ldc : N;
    // 1 when weights scales are per-N (stride into the scales buffer),
    // 0 for a common scale.
    const int scale_idx_mult
            = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_
            == (1 << (ndims - 1));

    // Worker threads report failures here; last writer wins.
    std::atomic<status_t> st(status::success);
    // use parallel over batch when binary po with channel bcast
    // (except batch == 1)
    bool is_binary_po_per_oc = false;
    bool is_binary_po_per_oc_sp = false;
    bool is_binary_po_channel_bcast = false;
    std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
            is_binary_po_channel_bcast)
            = bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
                    broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial);
    // if batched, parallel over batch for per_mb_sp and per_oc binary
    // post-op broadcast
    const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
            && IMPLICATION(
                    is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
    const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
    if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
        // Strategy 1: split the (batch x M x N) iteration space across
        // threads; each thread issues sequential SGEMMs on its chunks.
        const int src_mask
                = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
        const int wei_mask
                = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
        const size_t bia_dt_size = !pd()->with_bias()
                ? 0
                : types::data_type_size(pd()->weights_md(1)->data_type);
        const size_t work_amount = (size_t)batch * M * N;
        const size_t work_per_batch = (size_t)M * N;
        parallel(nthr, [&](int ithr, int nthr) {
            size_t t_work_start {0}, t_work_end {0};
            balance211(work_amount, nthr, ithr, t_work_start, t_work_end);

            dim_t cur_b {0}, cur_m {0}, cur_n {0};
            dims_t s_dims_idx, w_dims_idx, d_dims_idx;
            size_t i_work = t_work_start;
            // When acc is a scratchpad buffer, each thread reuses its
            // own acc_stride-sized slice; when dst is the accumulator,
            // curr_acc is recomputed per chunk from the dst offset.
            const bool reuse_acc = acc != (acc_data_t *)dst;
            acc_data_t *curr_acc
                    = reuse_acc ? acc + ithr * acc_stride : nullptr;

            while (i_work < t_work_end) {
                // Decompose the flat work index into (batch, M, N)
                // coordinates and full dst logical indices.
                utils::nd_iterator_init(
                        i_work, cur_b, batch, cur_m, M, cur_n, N);

                utils::l_dims_by_l_offset(
                        d_dims_idx, i_work, dst_d.dims(), ndims);

                // Map dst batch indices onto src/weights indices,
                // honoring broadcast (masked-out dims stay 0).
                utils::copy_dims_with_mask(
                        s_dims_idx, d_dims_idx, batch_ndims, src_mask);
                s_dims_idx[ndims - 2] = cur_m;
                s_dims_idx[ndims - 1] = 0; // k idx is always 0

                utils::copy_dims_with_mask(
                        w_dims_idx, d_dims_idx, batch_ndims, wei_mask);
                w_dims_idx[ndims - 2] = 0; // k idx is always 0
                w_dims_idx[ndims - 1] = cur_n;

                const src_data_t *curr_src = src + src_d.off_v(s_dims_idx);
                const weights_data_t *curr_weights
                        = weights + weights_d.off_v(w_dims_idx);
                const dim_t dst_off = dst_d.off_v(d_dims_idx);
                dst_data_t *curr_dst = dst + dst_off;
                if (!reuse_acc) curr_acc = acc + dst_off;
                dim_t gemm_M {0}, gemm_N {0};

                // Pick the largest contiguous chunk that fits in the
                // remaining work: a whole matrix, whole rows, or a
                // partial row.
                size_t matrix_offset;
                const size_t rem_work = t_work_end - i_work;
                if (rem_work >= work_per_batch && cur_m == 0 && cur_n == 0) {
                    // parallel over batch
                    gemm_M = M;
                    gemm_N = N;
                    matrix_offset = 0;
                } else if (rem_work >= (size_t)N && cur_n == 0) {
                    // parallel over M
                    gemm_M = nstl::min(
                            (size_t)(M - cur_m), (size_t)(rem_work / N));
                    gemm_N = N;
                    matrix_offset = cur_n + cur_m * N;
                } else {
                    // parallel over N
                    gemm_M = 1;
                    gemm_N = nstl::min((size_t)(N - cur_n), rem_work);
                    matrix_offset = cur_n + cur_m * N;
                }

                // Row-major matmul expressed as column-major sgemm by
                // swapping A/B and M/N; force the sequential kernel
                // (last arg) since we are already inside parallel().
                status_t st_thr = extended_sgemm(&transB, &transA, &gemm_N,
                        &gemm_M, &K, &alpha, curr_weights, &ldb, curr_src, &lda,
                        &beta, curr_acc, &acc_ldc, nullptr, false);
                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (params.has_pp_kernel_) {
                    const float *pp_scales
                            = params.get_post_processing_scales(scales);
                    const size_t dst_logical_off = i_work;
                    const size_t dim1_off = helper.ndims() > 3
                            ? ((cur_b % batch_without_dim0)
                                    / batch_without_dim01)
                            : cur_m;

                    // offset for case with post-op broadcast_channel
                    const size_t matrix_per_first_batch_off = helper.ndims() > 3
                            ? M * N * (cur_b / batch_without_dim0)
                                    + matrix_offset
                            : 0;
                    // Output-channel position, used to index bias and
                    // per-N scales.
                    const ptrdiff_t oc_off = i_work % N;
                    (*pp_kernel_)(curr_dst, curr_acc,
                            bias + oc_off * bia_dt_size,
                            pp_scales + oc_off * scale_idx_mult, dst_scales[0],
                            0, dst_logical_off, dim1_off, gemm_M * gemm_N,
                            static_cast<size_t>(N), ldc, nullptr,
                            post_ops_binary_rhs_arg_vec.data(), dst,
                            matrix_per_first_batch_off, ctx, *pd()->dst_md());
                }
                i_work += gemm_M * gemm_N;
            }
        });
    } else {
        // Strategy 2: one big SGEMM.
        // collapse batch into M, if weights batch dimensions are broadcasted.
        M = batch * M;

        st = extended_sgemm(&transB, &transA, &N, &M, &K, &alpha, weights, &ldb,
                src, &lda, &beta, acc, &acc_ldc, nullptr, false);

        if (st == status::success && params.has_pp_kernel_) {
            const bool force_sequential = pp_kernel_->sequential_kernel();
            const float *pp_scales = params.get_post_processing_scales(scales);
            parallel(force_sequential ? 1 : nthr, [&](int ithr, int nthr) {
                size_t start {}, end {};
                balance211((size_t)(M * N), nthr, ithr, start, end);
                const size_t dst_logical_off = start;
                // Position within the row (output-channel index) at
                // which this thread's range begins.
                const size_t dst_start_row_idx = start % N;
                (*pp_kernel_)(dst, acc, bias, pp_scales, dst_scales[0], start,
                        dst_logical_off, dst_start_row_idx, end, (size_t)N, ldc,
                        nullptr, post_ops_binary_rhs_arg_vec.data(), dst, 0,
                        ctx, *pd()->dst_md());
            });
        }
    }

    if (need_free_acc) free(acc);

    return st;
}
395 | |
396 | } // namespace matmul |
397 | } // namespace cpu |
398 | } // namespace impl |
399 | } // namespace dnnl |
400 | |