1/* Copyright 2020 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Implementation of CreateTrMulParams, see function comment.
17
18#ifndef RUY_RUY_CREATE_TRMUL_PARAMS_H_
19#define RUY_RUY_CREATE_TRMUL_PARAMS_H_
20
21#include <cstdint>
22#include <cstring>
23#include <type_traits>
24
25#include "ruy/allocator.h"
26#include "ruy/ctx.h"
27#include "ruy/kernel.h"
28#include "ruy/mat.h"
29#include "ruy/mul_params.h"
30#include "ruy/pack.h"
31#include "ruy/path.h"
32#include "ruy/performance_advisory.h"
33#include "ruy/trace.h"
34#include "ruy/trmul_params.h"
35
36namespace ruy {
37// While the only entry point to this file is CreateTrMulParams, its templatized
38// nature requires putting more code in this header than we would like. This
39// internal implementation code is enclosed in namespace 'detail'.
40namespace detail {
41
42inline void CreatePackedLayout(const MatLayout& src,
43 const KernelLayout& kernel_layout,
44 PMatLayout* packed_layout) {
45 // Packed matrices are always column-major, because in TrMul that is always
46 // the dimension of traversal of the kernel's inner loop.
47 packed_layout->order = Order::kColMajor;
48 packed_layout->rows = round_up_pot(src.rows, kernel_layout.rows);
49 packed_layout->cols = round_up_pot(src.cols, kernel_layout.cols);
50 packed_layout->stride = packed_layout->rows;
51 packed_layout->kernel = kernel_layout;
52}
53
54template <typename Scalar, typename PackedScalar>
55void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout,
56 TrMulParams* params) {
57 // Ruy always uses 32-bit signed accumulators for quantized
58 // matrix multiplication, so we would like to always use std::int32_t
59 // unconditionally for SumsType.
60 // However, for floating point types, we still need a reasonable type here to
61 // avoid tripping assertions elsewhere in the code.
62 using SumsType =
63 typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
64 std::int32_t>::type;
65
66 const EMat& src = params->src[side];
67 PEMat* packed_matrix = &params->packed_matrix[side];
68 packed_matrix->data_type = Type::Create<PackedScalar>();
69 packed_matrix->sums_type = Type::Create<SumsType>();
70 CreatePackedLayout(src.layout, kernel_layout, &packed_matrix->layout);
71 packed_matrix->zero_point = Pack<PackedScalar, Scalar>(src.zero_point);
72}
73
74template <typename KernelType>
75struct CheckKernelPathImpl {
76 static void Run(Path) {
77 // Do nothing.
78 // Path fallbacks are normal in general (see RUY_INHERIT_KERNEL).
79 // That is to say that one may instantiate ruy::Mul with a weird combination
80 // of types, such as LhsScalar==float and RhsScalar==double, and have it
81 // work by silently falling back to Path::kStandardCpp. Only in specific
82 // cases do we have dedicated kernels overriding that fallback, and that is
83 // what partial specializations of this template will check.
84 }
85};
86
#if RUY_DCHECK_IS_ENABLED
// Debug-only specialization for kernels whose LHS and RHS scalar types match.
// When that scalar type is one for which ruy ships dedicated kernels, a
// fallback to another Path's kernel would be a bug, so we assert that the
// kernel's own Path is the expected one.
//
// NOTE(review): this pattern places DstScalar and MulParams<AccumScalar,
// DstScalar> in the last two Kernel slots. However, PopulateTrMulParams
// instantiates Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, AccumScalar,
// DstScalar>, which this pattern does not appear to match. If so, the primary
// (no-op) template is what actually runs. TODO confirm, then update the
// pattern if the check is still wanted.
template <Path ThePath, typename SrcScalar, typename AccumScalar,
          typename DstScalar>
struct CheckKernelPathImpl<Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
                                  MulParams<AccumScalar, DstScalar>>>
    final {
  using KernelType = Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
                            MulParams<AccumScalar, DstScalar>>;
  static void Run(Path expected_path) {
    // Given ruy's current ARM and x86 support, dedicated kernels exist for
    // float and int8. uint8 is covered too, because packing converts uint8 to
    // int8 (see PackedTypeImpl).
    static constexpr bool kHasDedicatedKernels =
        std::is_same<SrcScalar, std::uint8_t>::value ||
        std::is_same<SrcScalar, std::int8_t>::value ||
        std::is_same<SrcScalar, float>::value;
    if (!kHasDedicatedKernels) {
      return;  // Fallbacks are legitimate for other scalar types.
    }
    RUY_DCHECK_EQ(expected_path, KernelType::kPath);
  }
};
#endif
114
115template <typename KernelType>
116void CheckKernelPath(Path expected_path) {
117 CheckKernelPathImpl<KernelType>::Run(expected_path);
118}
119
// Fills in the Path-specific parts of `params` for the compile-time path
// `ThePath`:
//  - params->path;
//  - the packed LHS/RHS matrix descriptions (see CreatePackedMatrix);
//  - the run_pack function pointers for both sides;
//  - the run_kernel function pointer.
//
// LhsScalar/RhsScalar are the source scalar types. PackedType<ThePath, ...>
// maps each one to the scalar type stored in the packed matrix, and the kernel
// is selected on those packed types.
//
// NOTE(review): RUY_TRACE_SCOPE / RUY_TRACE_INFO are macros from ruy/trace.h
// that may refer to the local names below. Keep those names and the statement
// order stable when editing. TODO confirm against trace.h.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void PopulateTrMulParams(TrMulParams* params) {
  RUY_TRACE_SCOPE;
  using PackedLhsScalar = PackedType<ThePath, LhsScalar>;
  using PackedRhsScalar = PackedType<ThePath, RhsScalar>;
  // This local alias shadows the ruy::Kernel class template for the rest of
  // the function body.
  using Kernel =
      Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, AccumScalar, DstScalar>;
  // Per-side kernel layouts. CreatePackedLayout rounds the packed dimensions
  // up to these.
  using LhsKernelLayout = typename Kernel::LhsLayout;
  using RhsKernelLayout = typename Kernel::RhsLayout;

  params->path = ThePath;

  CreatePackedMatrix<LhsScalar, PackedLhsScalar>(
      Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params);
  CreatePackedMatrix<RhsScalar, PackedRhsScalar>(
      Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params);
  params->run_pack[Side::kLhs] =
      &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>;
  params->run_pack[Side::kRhs] =
      &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>;
  params->run_kernel = &RunKernel<Kernel>::Run;
  // In RUY_DCHECK_IS_ENABLED builds, may assert that Kernel is not a fallback
  // from another Path (see CheckKernelPathImpl).
  CheckKernelPath<Kernel>(ThePath);
  RUY_TRACE_INFO(POPULATE_TRMUL_PARAMS);
}
145
146// PopulateTrMulParamsAllCompiledPaths calls into one of multiple
147// instantiations of PopulateTrMulParams. For each bit that is set in
148// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path
149// corresponding to that single bit. The call to PopulateTrMulParams is
150// guarded by a runtime check that it is in fact the dynamically selected path.
151//
// PopulateTrMulParamsAllCompiledPaths is implemented with template
// metaprogramming by mutual recursion between PathSearchCountdown and
// PathSearchOnlyCompiledPaths.
155//
156// PopulateTrMulParamsAllCompiledPaths is logically implementing the following
157// computation:
158//
159// template <Path CompiledPaths>
160// void PopulateTrMulParamsAllCompiledPaths(Path the_path,
161// TrMulParams* params) {
162// for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1]
163// Path current_path = static_cast<Path>(1 << bit);
164// if ((CompiledPaths & current_path) != Path::kNone) { // [2]
165// if (current_path == the_path) { // [3]
//       PopulateTrMulParams<current_path, ...>(params);
167// return;
168// }
169// }
170// }
171// }
172//
173//
174//
175// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is
176// done in the recursion of PathSearchOnlyCompiledPaths.
177// [2] - Done by PathSearchOnlyCompiledPaths's partial template
178// specialization on InCompiledPaths. This is the check which necessitates
179// doing the whole computation at C++ compile time.
180// [3] - Done by the `if` in the main definition of
181// PathSearchOnlyCompiledPaths.
182//
183// The template metaprogramming is necessary because:
184// - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++
185// compile-time constant.
186// - PopulateTrMulParamsAllCompiledPaths must not instantiate
187// inner loops for paths that are not in CompiledPaths, since that can result in
188// bogus instantiations which cause a compile time failure.
189template <Path CompiledPaths, int BitNumber, typename LhsScalar,
190 typename RhsScalar, typename AccumScalar, typename DstScalar>
191struct PathSearchCountdown;
192
193template <Path CompiledPaths, bool InCompiledPaths, int BitNumber,
194 typename LhsScalar, typename RhsScalar, typename AccumScalar,
195 typename DstScalar>
196struct PathSearchOnlyCompiledPaths {
197 static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
198 static void Search(Path the_path, TrMulParams* params) {
199 if (kCurrentPath == the_path) {
200 PopulateTrMulParams<kCurrentPath, LhsScalar, RhsScalar, AccumScalar,
201 DstScalar>(params);
202 return;
203 }
204 PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar,
205 AccumScalar, DstScalar>::Search(the_path, params);
206 }
207};
208
209// Skip this iteration if CompiledPaths doesn't contain the specified path.
210template <Path CompiledPaths, int BitNumber, typename LhsScalar,
211 typename RhsScalar, typename AccumScalar, typename DstScalar>
212struct PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar,
213 RhsScalar, AccumScalar, DstScalar> {
214 static void Search(Path the_path, TrMulParams* params) {
215 PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar,
216 AccumScalar, DstScalar>::Search(the_path, params);
217 }
218};
219
220template <Path CompiledPaths, int BitNumber, typename LhsScalar,
221 typename RhsScalar, typename AccumScalar, typename DstScalar>
222struct PathSearchCountdown {
223 static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
224 static void Search(Path the_path, TrMulParams* params) {
225 PathSearchOnlyCompiledPaths<
226 CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber,
227 LhsScalar, RhsScalar, AccumScalar, DstScalar>::Search(the_path, params);
228 }
229};
230
231// Termination of the countdown. If the counter reaches -1, then we haven't
232// found the specified path.
233template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
234 typename AccumScalar, typename DstScalar>
235struct PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, AccumScalar,
236 DstScalar> {
237 static void Search(Path, TrMulParams*) { RUY_DCHECK(false); }
238};
239
240template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
241 typename AccumScalar, typename DstScalar>
242void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) {
243 RUY_TRACE_SCOPE;
244 return PathSearchCountdown<CompiledPaths, 8 * sizeof(Path) - 1, LhsScalar,
245 RhsScalar, AccumScalar,
246 DstScalar>::Search(the_path, params);
247}
248
249template <typename AccumScalar, typename DstScalar>
250void AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized(
251 const MulParams<AccumScalar, DstScalar>& mul_params, int user_size,
252 int user_capacity) {
253#if RUY_DCHECK_IS_ENABLED
254 if (mul_params.bias()) {
255 for (int i = user_size; i < user_capacity; i++) {
256 RUY_DCHECK_EQ(mul_params.bias()[i], 0);
257 }
258 }
259 if (mul_params.multiplier_fixedpoint_perchannel()) {
260 for (int i = user_size; i < user_capacity; i++) {
261 RUY_DCHECK_EQ(mul_params.multiplier_fixedpoint_perchannel()[i], 0);
262 }
263 }
264 if (mul_params.multiplier_exponent_perchannel()) {
265 for (int i = user_size; i < user_capacity; i++) {
266 RUY_DCHECK_EQ(mul_params.multiplier_exponent_perchannel()[i], 0);
267 }
268 }
269#else
270 (void)mul_params;
271 (void)user_size;
272 (void)user_capacity;
273#endif
274}
275
276template <typename AccumScalar, typename DstScalar,
277 bool HaveQuantizedMultipliers =
278 std::is_same<AccumScalar, std::int32_t>::value &&
279 !std::is_same<DstScalar, std::int32_t>::value>
280struct EnsurePerChannelBuffersLargeEnoughImpl {
281 static void Run(const TrMulParams& params, Allocator* allocator,
282 MulParams<AccumScalar, DstScalar>* mul_params) {
283 const Side channel_side =
284 mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
285 : Side::kRhs;
286 const int required_capacity =
287 params.packed_matrix[channel_side].layout.cols;
288 const int user_size = params.src[channel_side].layout.cols;
289 const int user_capacity = round_up_pot(
290 user_size, mul_params->perchannel_buffers_capacity_rounding());
291 // We should have already checked earlier for the case where
292 // user_capacity >= required_capacity.
293 RUY_DCHECK_GT(required_capacity, user_capacity);
294 if (mul_params->bias()) {
295 AccumScalar* new_data =
296 allocator->Allocate<AccumScalar>(required_capacity);
297 std::memcpy(new_data, mul_params->bias(),
298 user_size * sizeof(AccumScalar));
299 std::memset(new_data + user_size, 0,
300 (required_capacity - user_size) * sizeof(AccumScalar));
301 mul_params->set_bias(new_data);
302 }
303 if (mul_params->multiplier_fixedpoint_perchannel()) {
304 AccumScalar* new_data =
305 allocator->Allocate<AccumScalar>(required_capacity);
306 std::memcpy(new_data, mul_params->multiplier_fixedpoint_perchannel(),
307 user_size * sizeof(AccumScalar));
308 std::memset(new_data + user_size, 0,
309 (required_capacity - user_size) * sizeof(AccumScalar));
310 mul_params->set_multiplier_fixedpoint_perchannel(new_data);
311 }
312 if (mul_params->multiplier_exponent_perchannel()) {
313 int* new_data = allocator->Allocate<int>(required_capacity);
314 std::memcpy(new_data, mul_params->multiplier_exponent_perchannel(),
315 user_size * sizeof(int));
316 std::memset(new_data + user_size, 0,
317 (required_capacity - user_size) * sizeof(int));
318 mul_params->set_multiplier_exponent_perchannel(new_data);
319 }
320 }
321};
322
323template <typename AccumScalar, typename DstScalar>
324struct EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar, false> {
325 static void Run(const TrMulParams& params, Allocator* allocator,
326 MulParams<AccumScalar, DstScalar>* mul_params) {
327 const Side channel_side =
328 mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
329 : Side::kRhs;
330 const int required_capacity =
331 params.packed_matrix[channel_side].layout.cols;
332 const int user_size = params.src[channel_side].layout.cols;
333 const int user_capacity = round_up_pot(
334 user_size, mul_params->perchannel_buffers_capacity_rounding());
335 // We should have already checked earlier for the case where
336 // user_capacity >= required_capacity.
337 RUY_DCHECK_GT(required_capacity, user_capacity);
338 if (mul_params->bias()) {
339 AccumScalar* new_data =
340 allocator->Allocate<AccumScalar>(required_capacity);
341 std::memcpy(new_data, mul_params->bias(),
342 user_size * sizeof(AccumScalar));
343 std::memset(new_data + user_size, 0,
344 (required_capacity - user_size) * sizeof(AccumScalar));
345 mul_params->set_bias(new_data);
346 }
347 }
348};
349
350template <typename AccumScalar, typename DstScalar>
351void EnsurePerChannelBuffersLargeEnough(
352 const TrMulParams& params, Ctx* ctx,
353 MulParams<AccumScalar, DstScalar>* mul_params) {
354 // Early exit in the common case where the packed matrix size matches the
355 // number of channels (as opposed to having been rounded up to a slightly
356 // larger value).
357 const Side channel_side =
358 mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
359 : Side::kRhs;
360 const int required_capacity = params.packed_matrix[channel_side].layout.cols;
361 const int user_size = params.src[channel_side].layout.cols;
362 const int user_capacity = round_up_pot(
363 user_size, mul_params->perchannel_buffers_capacity_rounding());
364 AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized(
365 *mul_params, user_size, user_capacity);
366 if (required_capacity <= user_capacity) {
367 return;
368 }
369 ctx->set_performance_advisory(
370 PerformanceAdvisory::kReallocatedPerChannelBuffer);
371 EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar>::Run(
372 params, ctx->GetMainAllocator(), mul_params);
373}
374
375// Ensures that `params->mul_params_bytes` contains MulParams data that's ready
376// to be consumed by the kernel. As a first-order approximation, that is simply
377// copying the user-provided `mul_params`, however there are a few changes.
378//
379// 1. The specified `channel_dimension` value overrides the channel_dimension
380// member in `mul_params`. The reason why `channel_dimension` is being
381// special-cased among MulParams members is that we will need to transpose
382// MulParams, and that consists just in toggling channel_dimension.
383// 2. Per-channel buffers may be reallocated, see
384// EnsurePerChannelBuffersLargeEnough.
385template <typename AccumScalar, typename DstScalar>
386void FinalizeMulParams(const MulParams<AccumScalar, DstScalar>& mul_params,
387 ChannelDimension channel_dimension, Ctx* ctx,
388 TrMulParams* params) {
389 using MulParamsType = MulParams<AccumScalar, DstScalar>;
390 static_assert(alignof(MulParamsType) <= kMaxMulParamsAlignment, "");
391 static_assert(sizeof(MulParamsType) <= kMaxMulParamsSize, "");
392 static_assert(std::is_trivially_copyable<MulParamsType>::value, "");
393 auto* dst_mul_params =
394 reinterpret_cast<MulParamsType*>(params->mul_params_bytes);
395 std::memcpy(dst_mul_params, &mul_params, sizeof(MulParamsType));
396 dst_mul_params->set_channel_dimension(channel_dimension);
397 EnsurePerChannelBuffersLargeEnough(*params, ctx, dst_mul_params);
398}
399
// Builds `params` from lhs, rhs, dst and mul_params, under the precondition
// that `dst` is column-major (RUY_DCHECK'd below). Steps:
//   1. Store EraseType'd lhs, rhs and dst in `params`.
//   2. Ask `ctx` which single Path to use among CompiledPaths.
//   3. Populate the path-specific packing/kernel data for that Path.
//   4. Finalize the MulParams copy stored in `params`.
//
// In this function, the `channel_dimension` parameter overrides the value
// of the channel_dimension member in the `mul_params` parameter. Callers that
// transpose the whole problem pass the transposed value. See the
// FinalizeMulParams comment.
//
// NOTE(review): RUY_TRACE_SCOPE / RUY_TRACE_INFO are macros from ruy/trace.h
// that may refer to local names here (e.g. the_path). Keep names and the
// statement order stable when editing. TODO confirm against trace.h.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void CreateTrMulParamsAssumingColMajorDst(
    const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
    const Mat<DstScalar>& dst,
    const MulParams<AccumScalar, DstScalar>& mul_params,
    ChannelDimension channel_dimension, Ctx* ctx, TrMulParams* params) {
  RUY_TRACE_SCOPE;
  RUY_DCHECK(IsColMajor(dst.layout));

  // Fill in the fields we already know.
  params->src[Side::kLhs] = EraseType(lhs);
  params->src[Side::kRhs] = EraseType(rhs);
  params->dst = EraseType(dst);

  // Determine which exact Path we're going to take in this Mul call.
  // This is cheap because it's cached in `ctx`. In user scenarios this always
  // evaluates to the same value on a given machine with given `CompiledPaths`,
  // but could be invalidated by a call to Ctx::SetRuntimeEnabledPaths(), which
  // might be exposed publicly in Context in the future.
  const Path the_path = ctx->SelectPath(CompiledPaths);

  RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST);

  // If we ever need again to fall back to Path::kStandardCpp, this is a good
  // place to do it -- just pass Path::kStandardCpp as both the template and
  // runtime parameters in this function call.
  // In the past we did that here (as version control history remembers).
  // A typical reason why we might need to resurrect that is if we implement
  // a new Path (i.e. port to a new ISA) and need to subdivide that work into
  // a series of incremental changes.
  PopulateTrMulParamsAllCompiledPaths<CompiledPaths, LhsScalar, RhsScalar,
                                      AccumScalar, DstScalar>(the_path, params);

  // This must be done last, as it depends on the specific choice of kernel.
  // Specifically, the EnsurePerChannelBuffersLargeEnough part of this will read
  // the packed matrix layouts that are written to `params` by the above
  // PopulateTrMulParams* call.
  FinalizeMulParams(mul_params, channel_dimension, ctx, params);
}
443
444} // namespace detail
445
446inline ChannelDimension Transpose(ChannelDimension channel_dimension) {
447 return channel_dimension == ChannelDimension::kCol ? ChannelDimension::kRow
448 : ChannelDimension::kCol;
449}
450
// CreateTrMulParams's output is a TrMulParams object that encodes
// all of the input information required by the middle-end, that is,
// the TrMul function.
//
// CreateTrMulParams performs the following tasks:
//   1. Reduce to the case of column-major destination, by transposing the
//      whole problem as needed: lhs and rhs are swapped, dst is transposed,
//      and the channel dimension is toggled.
//   2. Select the single code path to be taken, out of the set of paths
//      described by the `CompiledPaths` template parameter, based on the
//      runtime path selection made by `ctx` (see Ctx::SelectPath).
//   3. Perform type-erasure, converting templatized typed input parameters
//      to the un-typed data stored in TrMulParams.
//
// NOTE(review): RUY_TRACE_SCOPE / RUY_TRACE_INFO are macros from ruy/trace.h
// that may refer to local names. Keep names stable when editing.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void CreateTrMulParams(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
                       const Mat<DstScalar>& dst,
                       const MulParams<AccumScalar, DstScalar>& mul_params,
                       Ctx* ctx, TrMulParams* params) {
  RUY_TRACE_SCOPE;
  ChannelDimension channel_dimension = mul_params.channel_dimension();
  if (IsColMajor(dst.layout)) {
    detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>(
        lhs, rhs, dst, mul_params, channel_dimension, ctx, params);
  } else {
    RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_TRANSPOSING);
    // Row-major dst: solve the transposed problem instead, with lhs and rhs
    // swapped, dst transposed, and the channel dimension toggled to match.
    detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>(
        rhs, lhs, Transpose(dst), mul_params, Transpose(channel_dimension), ctx,
        params);
  }
}
481
482} // namespace ruy
483
484#endif // RUY_RUY_CREATE_TRMUL_PARAMS_H_
485