/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_H_
#define RUY_RUY_KERNEL_H_

#include <algorithm>

#include "ruy/kernel_common.h"
#include "ruy/mul_params.h"
#include "ruy/platform.h"
#include "ruy/trace.h"

// IWYU pragma: begin_exports
#if RUY_PLATFORM_NEON
#include "ruy/kernel_arm.h"
#elif RUY_PLATFORM_X86
#include "ruy/kernel_x86.h"
#endif
// IWYU pragma: end_exports

namespace ruy {

// KernelArgs is a helper to access the template parameter values from a Kernel
// template instantiation.
template <typename KernelType>
struct KernelArgs {};

template <Path tPath, typename tLhsScalar, typename tRhsScalar,
          typename tAccumScalar, typename tDstScalar>
struct KernelArgs<
    Kernel<tPath, tLhsScalar, tRhsScalar, tAccumScalar, tDstScalar>> {
  static constexpr Path kPath = tPath;
  using LhsScalar = tLhsScalar;
  using RhsScalar = tRhsScalar;
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;
};
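
// For example (purely illustrative): given
//   using K = Kernel<Path::kStandardCpp, float, float, float, float>;
// the specialization above yields KernelArgs<K>::kPath == Path::kStandardCpp
// and KernelArgs<K>::DstScalar == float.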

// RunKernel::Run() is the only place that directly invokes Kernel::Run().
// It performs the type un-erasure, and factoring all Kernel::Run() calls
// through this function also gives a single place to conditionally implement
// RUY_OPT(FAT_KERNEL). This should be a function, but it is a class in order
// to hide and share some boilerplate (see the member types, and the RunTyped
// method also using them).
template <typename KernelType>
class RunKernel final {
 public:
  static void Run(Tuning tuning, const SidePair<PEMat>& src,
                  const void* mul_params, const SidePair<int>& start,
                  const SidePair<int>& end, EMat* dst) {
    RUY_TRACE_SCOPE_NAME("RunKernel");
    const auto& unerased_lhs = UneraseType<LhsScalar>(src[Side::kLhs]);
    const auto& unerased_rhs = UneraseType<RhsScalar>(src[Side::kRhs]);
    auto unerased_dst = UneraseType<DstScalar>(*dst);
    RUY_TRACE_INFO(RUN_KERNEL);
    RunTyped(tuning, unerased_lhs, unerased_rhs,
             *static_cast<const MulParamsType*>(mul_params), start, end,
             &unerased_dst);
  }

 private:
  using Args = KernelArgs<KernelType>;
  using LhsScalar = typename Args::LhsScalar;
  using RhsScalar = typename Args::RhsScalar;
  using AccumScalar = typename Args::AccumScalar;
  using DstScalar = typename Args::DstScalar;
  using MulParamsType = MulParams<AccumScalar, DstScalar>;
  static void RunTyped(Tuning tuning, const PMat<LhsScalar>& lhs,
                       const PMat<RhsScalar>& rhs,
                       const MulParamsType& mul_params,
                       const SidePair<int>& start, const SidePair<int>& end,
                       Mat<DstScalar>* dst) {
    const int start_row = start[Side::kLhs];
    const int start_col = start[Side::kRhs];
    const int end_row = end[Side::kLhs];
    const int end_col = end[Side::kRhs];
    KernelType kernel(tuning);
    using LhsLayout = typename KernelType::LhsLayout;
    using RhsLayout = typename KernelType::RhsLayout;
    // This is a good place to validate kernel layouts. The Kernel class
    // template itself isn't a good place to do that because it has
    // specializations.
    // The kRows of both sides have to match: in TrMul, kRows is the depth
    // dimension, on which LHS and RHS have to agree for the matrix
    // multiplication to be defined at all, so requiring the corresponding
    // dimension of the kernel layouts to also match is reasonable. If it
    // didn't match, then the packed matrices could have mismatching depth
    // dimensions even with the source matrices agreeing.
    static_assert(LhsLayout::kRows == RhsLayout::kRows, "");
    // The kernel layouts have to be power-of-two. This simplifies BlockMap
    // logic considerably. This also avoids leaking fine performance
    // optimization details up the stack. For instance, if one of the
    // dimensions were 6, then users might notice that optimal performance is
    // achieved with matrix dimensions that are multiples of 6, and might
    // start contorting their own application code to match that requirement,
    // in a way that would not be future-proof.
    static_assert(is_pot(LhsLayout::kRows), "");
    static_assert(is_pot(LhsLayout::kCols), "");
    static_assert(is_pot(RhsLayout::kRows), "");
    static_assert(is_pot(RhsLayout::kCols), "");
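    // For example, FixedKernelLayout<Order::kColMajor, 1, 8> passes these
    // checks (1 and 8 are both powers of two), whereas a hypothetical 6-wide
    // layout would be rejected.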
    // end_row and end_col may be larger than dst dimensions.
    // That is because kernels write directly to the destination matrix, whose
    // dimensions may not be a multiple of the kernel dimensions, and we try
    // to keep this annoyance localized as an implementation detail in
    // kernels, by allowing rounded-up values to be passed down as far as
    // possible. These assertions encode the contract.
    RUY_DCHECK_LE(0, start_row);
    RUY_DCHECK_LE(start_row, end_row);
    RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
    RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
    RUY_DCHECK_LE(0, start_col);
    RUY_DCHECK_LE(start_col, end_col);
    RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
    RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
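    // For instance (illustrative numbers): with dst->layout.rows == 10,
    // LhsLayout::kCols == 4 and start_row == 0, a caller may pass
    // end_row == 12, the next multiple of 4; the checks above permit that,
    // and kernels clamp against dst->layout.rows themselves.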
#if RUY_OPT(FAT_KERNEL)
    kernel.Run(lhs, rhs, mul_params, start_row, start_col, end_row, end_col,
               dst);
#else
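    // Non-fat path: walk the [start, end) range in kernel-block-sized steps,
    // clamping the last, possibly partial, block to end_row/end_col.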
    for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
      int block_end_col = std::min(col + RhsLayout::kCols, end_col);
      for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
        int block_end_row = std::min(row + LhsLayout::kCols, end_row);
        kernel.Run(lhs, rhs, mul_params, row, col, block_end_row,
                   block_end_col, dst);
      }
    }
#endif
  }
};
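
// Illustrative sketch, not a definition from this header: elsewhere in ruy,
// RunKernel<KernelType>::Run is typically taken by address and stored behind
// a type-erased function pointer, so that code which no longer knows the
// scalar types can still invoke the kernel. Along the lines of:
//
//   using RunKernelFn = void(Tuning, const SidePair<PEMat>&, const void*,
//                            const SidePair<int>&, const SidePair<int>&,
//                            EMat*);
//   RunKernelFn* run_kernel = &RunKernel<SomeKernelType>::Run;
//
// where SomeKernelType stands for a concrete Kernel instantiation.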

template <Path ThePath>
struct StandardCppKernelLayout {};

template <>
struct StandardCppKernelLayout<Path::kStandardCpp> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
};

// A variant exercising RowMajor square blocks.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant1> {
  using Lhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
  using Rhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
};

// A variant with a rectangular layout: 4x8.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant2> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 4>;
  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 8>;
};

// A variant with different block orders in LHS vs RHS.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant3> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 2, 16>;
  using Rhs = FixedKernelLayout<Order::kRowMajor, 2, 8>;
};
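
// In all of the above, FixedKernelLayout<Order, kRows, kCols> describes the
// block shape that a kernel consumes from one packed side: kRows runs along
// the depth dimension shared by LHS and RHS (see the static_assert in
// RunKernel::RunTyped), while kCols runs along that side's destination
// dimension (rows of dst for the LHS, columns of dst for the RHS).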

// General implementation of the Kernel template, overridden by template
// specializations for specific SIMD code paths. This general implementation
// covers Path::kStandardCpp and its internal test-only variants.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
struct Kernel {
  // Each Kernel specialization defines kPath as the ground-truth path that it
  // implements. This is used in assertions. As we support fallbacks between
  // paths (see RUY_INHERIT_KERNEL), unless a specialization for a specific
  // set of template parameters is defined, it is normal for template
  // instantiations of the form Kernel<SomePath, ...> to have
  // kPath != SomePath. Assertions that kPath == SomePath are used in places
  // where we know that we should be using a template specialization for a
  // specific path rather than a fallback.
  static constexpr Path kPath = ThePath;
  using MulParamsType = MulParams<AccumScalar, DstScalar>;
  using LhsLayout = typename StandardCppKernelLayout<ThePath>::Lhs;
  using RhsLayout = typename StandardCppKernelLayout<ThePath>::Rhs;
  explicit Kernel(Tuning) {}
  void Run(const PMat<LhsScalar>& lhs, const PMat<RhsScalar>& rhs,
           const MulParamsType& mul_params, int start_row, int start_col,
           int end_row, int end_col, Mat<DstScalar>* dst) const {
    // See the comment in RunKernel::RunTyped. end_row may be larger than
    // dst->layout.rows. It's the responsibility of the kernel to avoid
    // overrunning dst boundaries, which we do here by computing
    // clamped_end_row and clamped_end_col.
    int clamped_end_row = std::min(end_row, dst->layout.rows);
    int clamped_end_col = std::min(end_col, dst->layout.cols);
    RUY_DCHECK_LE(0, start_row);
    RUY_DCHECK_LE(start_row, clamped_end_row);
    RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
    RUY_DCHECK_LE(clamped_end_row, end_row);
    RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
    RUY_DCHECK_LE(0, start_col);
    RUY_DCHECK_LE(start_col, clamped_end_col);
    RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
    RUY_DCHECK_LE(clamped_end_col, end_col);
    RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
    profiler::ScopeLabel label("Kernel (Standard Cpp)");
    const int depth = lhs.layout.rows;
    for (int i = start_row; i < clamped_end_row; i++) {
      for (int j = start_col; j < clamped_end_col; j++) {
        AccumScalar accum = 0;
        for (int k = 0; k < depth; k++) {
          AccumScalar lhs_val = Element(lhs, k, i);
          AccumScalar rhs_val = Element(rhs, k, j);
          accum += lhs_val * rhs_val;
        }
        int channel =
            mul_params.channel_dimension() == ChannelDimension::kRow ? i : j;
        if (mul_params.bias()) {
          accum += mul_params.bias()[channel];
        }
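        // Zero-point corrections: a quantized value stands for
        // (stored_value - zero_point), so the true accumulator is
        //   sum_k (lhs(k,i) - lhs_zp) * (rhs(k,j) - rhs_zp)
        //     = sum_k lhs(k,i) * rhs(k,j)   // accum so far
        //     - lhs_zp * sum_k rhs(k,j)     // rhs.sums[j]
        //     - rhs_zp * sum_k lhs(k,i)     // lhs.sums[i]
        //     + lhs_zp * rhs_zp * depth.
        // The three conditionals below add exactly these correction terms.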
        if (lhs.zero_point) {
          accum -= lhs.zero_point * rhs.sums[j];
        }
        if (rhs.zero_point) {
          accum -= rhs.zero_point * lhs.sums[i];
        }
        if (lhs.zero_point && rhs.zero_point) {
          accum += lhs.zero_point * rhs.zero_point * depth;
        }
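        // Rescale the accumulator to the destination domain, then offset by
        // the destination zero point and clamp to the requested output range.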
        ApplyMultiplier(mul_params, channel, &accum);
        accum += dst->zero_point;
        accum = std::min<AccumScalar>(accum, mul_params.clamp_max());
        accum = std::max<AccumScalar>(accum, mul_params.clamp_min());
        *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
      }
    }
  }
};

}  // namespace ruy

#endif  // RUY_RUY_KERNEL_H_