/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_MATMUL_GEMM_BASED_COMMON_HPP
#define CPU_MATMUL_GEMM_BASED_COMMON_HPP

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/primitive_attr.hpp"
#include "common/type_helpers.hpp"

#include "cpu/matmul/cpu_matmul_pd.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace matmul {
namespace gemm_based {

struct params_t {
    // indicates if dst can be used as the accumulator directly, i.e. no
    // auxiliary array for intermediate computations is required
    bool dst_is_acc_;

    // indicates if output scales from attributes are applied
    // by gemm (alpha parameter) or post-op kernel (pp_kernel_)
    bool gemm_applies_output_scales_ = false;

    // sum post-op scaling factor that is fused into gemm
    float gemm_beta_ = 0.f;

    // indicates if a special post processing kernel
    // should be invoked after gemm
    bool has_pp_kernel_ = false;

    // indicates if src batch dims can be fused into M, so that a single
    // GeMM call can be made
    bool can_fuse_src_batch_dims_ = false;

    // default scale for the post processing kernel, used when gemm has
    // already applied the output scales itself
    float default_pp_scales_ = 1.0f;

    // an attribute for post processing kernel
    primitive_attr_t pp_attr_;

    // auxiliary functions

    // returns gemm alpha parameter (a single value for now)
    float get_gemm_alpha(const float *primitive_scales) const {
        return gemm_applies_output_scales_ ? primitive_scales[0] : 1.f;
    }

    // returns scaling factors for post processing kernel
    const float *get_post_processing_scales(
            const float *primitive_scales) const {
        return gemm_applies_output_scales_ ? &default_pp_scales_
                                           : primitive_scales;
    }
};

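// Illustrative sketch (not part of the implementation) of how a gemm-based
// matmul driver is expected to consume params_t; `p` and `scales` are
// hypothetical names for a configured params_t and the output scales taken
// from the primitive attributes:
//
//     const gemm_based::params_t &p = ...;
//     // alpha carries the output scale only when gemm applies it itself
//     const float alpha = p.get_gemm_alpha(scales);
//     // beta != 0 fuses the sum post-op: dst = alpha * src * wei + beta * dst
//     const float beta = p.gemm_beta_;
//     // otherwise the scales are handed to the post processing kernel
//     const float *pp_scales = p.get_post_processing_scales(scales);
//     if (p.has_pp_kernel_) { /* run pp_kernel_ with pp_scales */ }
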
inline bool check_gemm_compatible_formats(const matmul_pd_t &pd) {

    const memory_desc_wrapper dst_d(pd.dst_md());
    const int ndims = dst_d.ndims();

    auto check_input_format = [=](const memory_desc_t *md) {
        memory_desc_wrapper mdw(md);

        if (!mdw.is_plain()) return false;

        const dims_t &strides = mdw.blocking_desc().strides;

        // reject md with a zero stride for a particular dimension
        for (int dim = 0; dim < ndims; ++dim)
            if (strides[dim] == 0) return false;

        // for GeMM at least one of the two innermost axes must be contiguous
        return utils::one_of(1, strides[ndims - 1], strides[ndims - 2]);
    };

    bool ok = check_input_format(pd.src_md())
            && check_input_format(pd.weights_md()) && dst_d.is_plain()
            && dst_d.blocking_desc().strides[ndims - 1] == 1;

    return ok;
}

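// A concrete illustration (assumed 2D shapes) of what the check above
// accepts for an M x K input:
//
//     // row-major ("ab"), strides {K, 1}: innermost axis is dense -> ok
//     // column-major ("ba"), strides {1, M}: gemm can read it as
//     // transposed, since the second innermost axis is dense -> ok
//     // strides {2 * K, 2}: neither of the two innermost axes is dense
//     // -> rejected
//
// dst is held to the stricter requirement of a unit innermost stride.
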
inline bool check_gemm_binary_per_oc_compatible_formats(const matmul_pd_t &pd) {
    const memory_desc_wrapper dst_d(pd.dst_md());
    const dims_t &strides = dst_d.blocking_desc().strides;
    const dims_t &dims = dst_d.dims();
    const int ndims = dst_d.ndims();

    // check that d, h, w... (b2, m, n... for matmul) dimensions are contiguous
    bool ok = true;
    for (int i = 2; i < ndims - 1; i++)
        ok = ok && strides[i] == strides[i + 1] * dims[i + 1];
    // only allowed for nchw and nhwc (b0xb1xMxN or b0xMxNxb1 for matmul)
    return ok && strides[0] == utils::array_product(dims + 1, ndims - 1);
}

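// A concrete illustration for the check above, assuming a 4D dst with
// logical dims {b0, b1, M, N}:
//
//     // "abcd" (nchw-like), strides {b1 * M * N, M * N, N, 1} -> ok
//     // "acdb" (nhwc-like), strides {b1 * M * N, 1, b1 * N, b1} -> ok
//     // any layout where b1 is neither outermost nor innermost -> rejected
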
inline size_t get_scratchpad_size(const dim_t batch, dim_t M, const dim_t N,
        const bool can_fuse_src_batch_dims, const int nthr) {
    assert(batch > 0);
    assert(M > 0);
    assert(N > 0);
    size_t buffer_size;
    if (can_fuse_src_batch_dims || batch == 1) {
        buffer_size = (size_t)batch * M * N;
    } else {
        const size_t work_per_thr = utils::div_up((size_t)batch * M * N, nthr);
        if (work_per_thr >= (size_t)N) {
            buffer_size = nstl::min<size_t>(
                    (size_t)M * N, utils::rnd_dn(work_per_thr, N));
        } else {
            buffer_size = work_per_thr;
        }
    }
    return utils::rnd_up(buffer_size, 64);
}

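// Worked example for the sizing above, with assumed values batch = 4,
// M = 128, N = 64, nthr = 8 and batch dims that cannot be fused:
//
//     // total work: 4 * 128 * 64 = 32768 accumulator elements
//     // work_per_thr = div_up(32768, 8) = 4096 >= N
//     // buffer_size = min(128 * 64, rnd_dn(4096, 64)) = 4096
//     // rnd_up(4096, 64) = 4096 elements booked per thread
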
inline void book_acc_scratchpad(matmul_pd_t &pd, const params_t &params,
        size_t sizeof_acc_data, const int nthr) {

    if (!params.dst_is_acc_
            && !memory_desc_wrapper(pd.dst_md()).has_runtime_dims()) {
        const size_t buffer_size = get_scratchpad_size(pd.batch(), pd.M(),
                pd.N(), params.can_fuse_src_batch_dims_, nthr);
        const size_t sp_size = params.can_fuse_src_batch_dims_
                ? buffer_size
                : buffer_size * nthr;
        auto scratchpad = pd.scratchpad_registry().registrar();
        scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt,
                sp_size, sizeof_acc_data);
    }
}
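
// Hypothetical call site sketch: a gemm-based matmul pd would book the
// accumulator scratchpad at the end of its init(), once params_ is set up,
// e.g. book_acc_scratchpad(*this, params_, sizeof(acc_data_t), nthr);
// note that when the batch dims cannot be fused, each thread receives its
// own buffer, hence the buffer_size * nthr booking above.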

} // namespace gemm_based
} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif