/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_MATMUL_GEMM_BASED_COMMON_HPP
#define CPU_MATMUL_GEMM_BASED_COMMON_HPP

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/primitive_attr.hpp"
#include "common/type_helpers.hpp"

#include "cpu/matmul/cpu_matmul_pd.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace matmul {
namespace gemm_based {

struct params_t {
    // indicates if dst can be used as the accumulator directly, i.e. no
    // auxiliary array for intermediate computations is required
    bool dst_is_acc_;

    // indicates whether the output scales from the attributes are applied
    // by gemm (via the alpha parameter) or by the post-op kernel (pp_kernel_)
    bool gemm_applies_output_scales_ = false;

    // sum post-op scaling factor that is fused into gemm as the beta parameter
    float gemm_beta_ = 0.f;

    // indicates if a special post processing kernel
    // should be invoked after gemm
    bool has_pp_kernel_ = false;

    // indicates if src batch dims can be fused into M, so that a single
    // GeMM call can be made
    bool can_fuse_src_batch_dims_ = false;

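    // the scale to pass to the post processing kernel when gemm has already
    // applied the output scales (see get_post_processing_scales())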
    float default_pp_scales_ = 1.0f;

    // an attribute for the post processing kernel
    primitive_attr_t pp_attr_;

    // auxiliary functions

    // returns the gemm alpha parameter (a single value for now)
    float get_gemm_alpha(const float *primitive_scales) const {
        return gemm_applies_output_scales_ ? primitive_scales[0] : 1.f;
    }

    // returns the scaling factors for the post processing kernel
    const float *get_post_processing_scales(
            const float *primitive_scales) const {
        return gemm_applies_output_scales_ ? &default_pp_scales_
                                           : primitive_scales;
    }
};

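// Checks that src, weights, and dst use layouts a plain GeMM call can handle:
// all three must be plain (non-blocked), the inputs must have no zero strides,
// and the innermost axes must be dense as required below.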
inline bool check_gemm_compatible_formats(const matmul_pd_t &pd) {
    const memory_desc_wrapper dst_d(pd.dst_md());
    const int ndims = dst_d.ndims();

    auto check_input_format = [=](const memory_desc_t *md) {
        memory_desc_wrapper mdw(md);

        if (!mdw.is_plain()) return false;

        const dims_t &strides = mdw.blocking_desc().strides;

        // reject memory descriptors with a zero stride in any dimension
        for (int dim = 0; dim < ndims; ++dim)
            if (strides[dim] == 0) return false;

        // for GeMM, at least one of the two innermost axes must be contiguous
        return utils::one_of(1, strides[ndims - 1], strides[ndims - 2]);
    };

    bool ok = check_input_format(pd.src_md())
            && check_input_format(pd.weights_md()) && dst_d.is_plain()
            && dst_d.blocking_desc().strides[ndims - 1] == 1;

    return ok;
}

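// Checks that dst has a dense nchw- or nhwc-like layout (b0xb1xMxN or
// b0xMxNxb1 for matmul), which the gemm-based implementations require for
// per-oc binary post-ops.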
inline bool check_gemm_binary_per_oc_compatible_formats(const matmul_pd_t &pd) {
    const memory_desc_wrapper dst_d(pd.dst_md());
    const dims_t &strides = dst_d.blocking_desc().strides;
    const dims_t &dims = dst_d.dims();
    const int ndims = dst_d.ndims();

    // check that the d, h, w, ... (b2, m, n, ... for matmul) dimensions are
    // contiguous
    bool ok = true;
    for (int i = 2; i < ndims - 1; i++)
        ok = ok && strides[i] == strides[i + 1] * dims[i + 1];
    // only nchw- and nhwc-like layouts are allowed (b0xb1xMxN or b0xMxNxb1
    // for matmul)
    return ok && strides[0] == utils::array_product(dims + 1, ndims - 1);
}

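// Returns the number of accumulator elements to book in the scratchpad:
// enough for the whole problem when a single GeMM call covers all batches,
// otherwise enough for a single thread's share of the work.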
inline size_t get_scratchpad_size(const dim_t batch, const dim_t M,
        const dim_t N, const bool can_fuse_src_batch_dims, const int nthr) {
    assert(batch > 0);
    assert(M > 0);
    assert(N > 0);
    size_t buffer_size;
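    // a single GeMM call covers the whole problem, so one contiguous buffer
    // must hold all batch * M * N accumulators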
    if (can_fuse_src_batch_dims || batch == 1) {
        buffer_size = (size_t)batch * M * N;
    } else {
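        // each thread accumulates into its own buffer; size it to the
        // thread's share of the work, rounded down to whole rows of N when
        // the share spans at least one row (and capped at one M x N plane)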
        const size_t work_per_thr = utils::div_up((size_t)batch * M * N, nthr);
        if (work_per_thr >= (size_t)N) {
            buffer_size = nstl::min<size_t>(
                    (size_t)M * N, utils::rnd_dn(work_per_thr, N));
        } else {
            buffer_size = work_per_thr;
        }
    }
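    // pad to a multiple of 64 elements so that per-thread offsets into the
    // scratchpad stay cache-line aligned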
    return utils::rnd_up(buffer_size, 64);
}

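// Books scratchpad memory for an intermediate accumulator buffer. Nothing is
// booked when dst itself can serve as the accumulator, or when dst has
// runtime dims, in which case the buffer cannot be sized at pd creation time.
// A pd would typically call this from its init function, e.g. (sketch):
//     book_acc_scratchpad(*this, params_, sizeof(acc_data_t), nthr);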
inline void book_acc_scratchpad(matmul_pd_t &pd, const params_t &params,
        size_t sizeof_acc_data, const int nthr) {
    if (!params.dst_is_acc_
            && !memory_desc_wrapper(pd.dst_md()).has_runtime_dims()) {
        const size_t buffer_size = get_scratchpad_size(pd.batch(), pd.M(),
                pd.N(), params.can_fuse_src_batch_dims_, nthr);
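        // one shared buffer when a single GeMM call covers all batches,
        // otherwise a separate buffer per thread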
        const size_t sp_size = params.can_fuse_src_batch_dims_
                ? buffer_size
                : buffer_size * nthr;
        auto scratchpad = pd.scratchpad_registry().registrar();
        scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt,
                sp_size, sizeof_acc_data);
    }
}

} // namespace gemm_based
} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif