1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef RUY_RUY_KERNEL_H_
17#define RUY_RUY_KERNEL_H_
18
19#include "ruy/kernel_common.h"
20#include "ruy/mul_params.h"
21#include "ruy/platform.h"
22#include "ruy/trace.h"
23
24// IWYU pragma: begin_exports
25#if RUY_PLATFORM_NEON
26#include "ruy/kernel_arm.h"
27#elif RUY_PLATFORM_X86
28#include "ruy/kernel_x86.h"
29#endif
30// IWYU pragma: end_exports
31
32namespace ruy {
33
// KernelArgs is a helper to access the template parameter values from a Kernel
// template instantiation. The primary template is intentionally empty: the
// members are provided only by the partial specialization below, which
// matches actual Kernel<...> instantiations.
template <typename KernelType>
struct KernelArgs {};
38
// Partial specialization matching any Kernel<...> instantiation: re-exposes
// the Kernel's template arguments as named members so that generic code
// (e.g. RunKernel below) can recover them from a KernelType alone.
template <Path tPath, typename tLhsScalar, typename tRhsScalar,
          typename tAccumScalar, typename tDstScalar>
struct KernelArgs<
    Kernel<tPath, tLhsScalar, tRhsScalar, tAccumScalar, tDstScalar>> {
  static constexpr Path kPath = tPath;  // Code path this kernel implements.
  using LhsScalar = tLhsScalar;      // Scalar type of the packed LHS matrix.
  using RhsScalar = tRhsScalar;      // Scalar type of the packed RHS matrix.
  using AccumScalar = tAccumScalar;  // Type used for accumulation.
  using DstScalar = tDstScalar;      // Scalar type of the destination matrix.
};
49
50// RunKernel::Run() is the only place that directly invokes Kernel::Run().
51// It performs the types un-erasure, and factoring all Kernel::Run() calls
52// through this function also gives a single place where to conditionally
53// implement RUY_OPT(FAT_KERNEL). This should be a function but is a class to
54// hide and share some boilerplate (see the member types, and the RunTyped
55// method also using them).
56template <typename KernelType>
57class RunKernel final {
58 public:
59 static void Run(Tuning tuning, const SidePair<PEMat>& src,
60 const void* mul_params, const SidePair<int>& start,
61 const SidePair<int>& end, EMat* dst) {
62 RUY_TRACE_SCOPE_NAME("RunKernel");
63 const auto& unerased_lhs = UneraseType<LhsScalar>(src[Side::kLhs]);
64 const auto& unerased_rhs = UneraseType<RhsScalar>(src[Side::kRhs]);
65 auto unerased_dst = UneraseType<DstScalar>(*dst);
66 RUY_TRACE_INFO(RUN_KERNEL);
67 RunTyped(tuning, unerased_lhs, unerased_rhs,
68 *static_cast<const MulParamsType*>(mul_params), start, end,
69 &unerased_dst);
70 }
71
72 private:
73 using Args = KernelArgs<KernelType>;
74 using LhsScalar = typename Args::LhsScalar;
75 using RhsScalar = typename Args::RhsScalar;
76 using AccumScalar = typename Args::AccumScalar;
77 using DstScalar = typename Args::DstScalar;
78 using MulParamsType = MulParams<AccumScalar, DstScalar>;
79 static void RunTyped(Tuning tuning, const PMat<LhsScalar>& lhs,
80 const PMat<RhsScalar>& rhs,
81 const MulParamsType& mul_params,
82 const SidePair<int>& start, const SidePair<int>& end,
83 Mat<DstScalar>* dst) {
84 const int start_row = start[Side::kLhs];
85 const int start_col = start[Side::kRhs];
86 const int end_row = end[Side::kLhs];
87 const int end_col = end[Side::kRhs];
88 KernelType kernel(tuning);
89 using LhsLayout = typename KernelType::LhsLayout;
90 using RhsLayout = typename KernelType::RhsLayout;
91 // This is a good place to validate kernel layouts. The Kernel class
92 // template itself isn't a good place to do that because it has
93 // specializations.
94 // The kRows of both sides have to match: in TrMul, kRows is the depth
95 // dimension, on which LHS and RHS have to agree for the matrix
96 // multiplication to be defined at all, so requiring the corresponding
97 // dimension of the kernel layouts to also match is reasonable. If it didn't
98 // match, then the packed matrices could have mismatching depth dimensions
99 // even with the source matrices agreeing.
100 static_assert(LhsLayout::kRows == RhsLayout::kRows, "");
101 // The kernel layouts have to be power-of-two. This simplifies BlockMap
102 // logic considerably. This also avoids leaking fine performance
103 // optimization details up the stack. For instance, if one of the dimensions
104 // were 6, then users might notice that optimal performance is achieved with
105 // matrix dimensions that are multiples of 6, and might start contorting
106 // their own application code to match that requirement, in a way that would
107 // not be future-proof.
108 static_assert(is_pot(LhsLayout::kRows), "");
109 static_assert(is_pot(LhsLayout::kCols), "");
110 static_assert(is_pot(RhsLayout::kRows), "");
111 static_assert(is_pot(RhsLayout::kCols), "");
112 // end_row and end_col may be larger than dst dimensions.
113 // that is because kernels write directly to the destination matrix, whose
114 // dimensions may not be a multiple of the kernel dimensions, and we try to
115 // keep this annoyance localized as an implementation detail in kernels,
116 // by allowing to pass rounded-up values down as far as possible.
117 // These assertions encode the contract.
118 RUY_DCHECK_LE(0, start_row);
119 RUY_DCHECK_LE(start_row, end_row);
120 RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
121 RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
122 RUY_DCHECK_LE(0, start_col);
123 RUY_DCHECK_LE(start_col, end_col);
124 RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
125 RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
126#if RUY_OPT(FAT_KERNEL)
127 kernel.Run(lhs, rhs, mul_params, start_row, start_col, end_row, end_col, dst);
128#else
129 for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
130 int block_end_col = std::min(col + RhsLayout::kCols, end_col);
131 for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
132 int block_end_row = std::min(row + LhsLayout::kCols, end_row);
133 kernel.Run(lhs, rhs, mul_params, row, col, block_end_row, block_end_col,
134 dst);
135 }
136 }
137#endif
138 }
139};
140
// StandardCppKernelLayout maps a Path to the fixed kernel block layouts
// (Lhs/Rhs) used by the generic Kernel implementation below. The primary
// template is intentionally empty: only the per-Path specializations below
// define the layout aliases.
template <Path ThePath>
struct StandardCppKernelLayout {};
143
// Layouts for the plain Path::kStandardCpp reference path: trivial 1x1
// blocks, i.e. no actual blocking.
template <>
struct StandardCppKernelLayout<Path::kStandardCpp> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 1>;
};
149
// A test-only variant exercising RowMajor square 4x4 blocks on both sides.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant1> {
  using Lhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
  using Rhs = FixedKernelLayout<Order::kRowMajor, 4, 4>;
};
156
// A test-only variant with a rectangular kernel block: 4x8, i.e. 4
// destination rows (LhsLayout::kCols) by 8 destination columns
// (RhsLayout::kCols) per kernel invocation.
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant2> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 1, 4>;
  using Rhs = FixedKernelLayout<Order::kColMajor, 1, 8>;
};
163
// A test-only variant with different block orders in LHS (ColMajor) vs RHS
// (RowMajor).
template <>
struct StandardCppKernelLayout<Path::kInternalStandardCppVariant3> {
  using Lhs = FixedKernelLayout<Order::kColMajor, 2, 16>;
  using Rhs = FixedKernelLayout<Order::kRowMajor, 2, 8>;
};
170
171// General implementation of the Kernel template, overridden by template
172// specializations for specific SIMD code paths. This general implementation
173// covers Path::kStandardCpp and its internal test-only variants.
174template <Path ThePath, typename LhsScalar, typename RhsScalar,
175 typename AccumScalar, typename DstScalar>
176struct Kernel {
177 // Each Kernel specialization defines kPath as the ground-truth path that it
178 // implements. This is used in assertions. As we support fallbacks between
179 // paths (see RUY_INHERIT_KERNEL), Unless a specialization for a specific set
180 // of template parameters was defined, it is normal for template
181 // instantiations of the form Kernel<SomePath, ...> to have kPath!=SomePath.
182 // Assertions that kPath==SomePath are used in places where we know that we
183 // should be using a template specialization for a specific path rather than a
184 // fallback.
185 static constexpr Path kPath = ThePath;
186 using MulParamsType = MulParams<AccumScalar, DstScalar>;
187 using LhsLayout = typename StandardCppKernelLayout<ThePath>::Lhs;
188 using RhsLayout = typename StandardCppKernelLayout<ThePath>::Rhs;
189 explicit Kernel(Tuning) {}
190 void Run(const PMat<LhsScalar>& lhs, const PMat<RhsScalar>& rhs,
191 const MulParamsType& mul_params, int start_row, int start_col,
192 int end_row, int end_col, Mat<DstScalar>* dst) const {
193 // See the comment in RunKernelTyped. end_row may be larger than
194 // dst->layout.rows. It's the responsibility of the kernel to avoid
195 // overrunning dst boundaries, which we do here by computing
196 // clamped_end_row.
197 int clamped_end_row = std::min(end_row, dst->layout.rows);
198 int clamped_end_col = std::min(end_col, dst->layout.cols);
199 RUY_DCHECK_LE(0, start_row);
200 RUY_DCHECK_LE(start_row, clamped_end_row);
201 RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
202 RUY_DCHECK_LE(clamped_end_row, end_row);
203 RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
204 RUY_DCHECK_LE(0, start_col);
205 RUY_DCHECK_LE(start_col, clamped_end_col);
206 RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
207 RUY_DCHECK_LE(clamped_end_col, end_col);
208 RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
209 profiler::ScopeLabel label("Kernel (Standard Cpp)");
210 const int depth = lhs.layout.rows;
211 for (int i = start_row; i < clamped_end_row; i++) {
212 for (int j = start_col; j < clamped_end_col; j++) {
213 AccumScalar accum = 0;
214 for (int k = 0; k < depth; k++) {
215 AccumScalar lhs_val = Element(lhs, k, i);
216 AccumScalar rhs_val = Element(rhs, k, j);
217 accum += lhs_val * rhs_val;
218 }
219 int channel =
220 mul_params.channel_dimension() == ChannelDimension::kRow ? i : j;
221 if (mul_params.bias()) {
222 accum += mul_params.bias()[channel];
223 }
224 if (lhs.zero_point) {
225 accum -= lhs.zero_point * rhs.sums[j];
226 }
227 if (rhs.zero_point) {
228 accum -= rhs.zero_point * lhs.sums[i];
229 }
230 if (lhs.zero_point && rhs.zero_point) {
231 accum += lhs.zero_point * rhs.zero_point * depth;
232 }
233 ApplyMultiplier(mul_params, channel, &accum);
234 accum += dst->zero_point;
235 accum = std::min<AccumScalar>(accum, mul_params.clamp_max());
236 accum = std::max<AccumScalar>(accum, mul_params.clamp_min());
237 *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
238 }
239 }
240 }
241};
242
243} // namespace ruy
244
245#endif // RUY_RUY_KERNEL_H_
246