1 | /* Copyright 2020 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Implementation of CreateTrMulParams, see function comment. |
17 | |
18 | #ifndef RUY_RUY_CREATE_TRMUL_PARAMS_H_ |
19 | #define RUY_RUY_CREATE_TRMUL_PARAMS_H_ |
20 | |
21 | #include <cstdint> |
22 | #include <cstring> |
23 | #include <type_traits> |
24 | |
25 | #include "ruy/allocator.h" |
26 | #include "ruy/ctx.h" |
27 | #include "ruy/kernel.h" |
28 | #include "ruy/mat.h" |
29 | #include "ruy/mul_params.h" |
30 | #include "ruy/pack.h" |
31 | #include "ruy/path.h" |
32 | #include "ruy/performance_advisory.h" |
33 | #include "ruy/trace.h" |
34 | #include "ruy/trmul_params.h" |
35 | |
36 | namespace ruy { |
37 | // While the only entry point to this file is CreateTrMulParams, its templatized |
38 | // nature requires putting more code in this header than we would like. This |
39 | // internal implementation code is enclosed in namespace 'detail'. |
40 | namespace detail { |
41 | |
42 | inline void CreatePackedLayout(const MatLayout& src, |
43 | const KernelLayout& kernel_layout, |
44 | PMatLayout* packed_layout) { |
45 | // Packed matrices are always column-major, because in TrMul that is always |
46 | // the dimension of traversal of the kernel's inner loop. |
47 | packed_layout->order = Order::kColMajor; |
48 | packed_layout->rows = round_up_pot(src.rows, kernel_layout.rows); |
49 | packed_layout->cols = round_up_pot(src.cols, kernel_layout.cols); |
50 | packed_layout->stride = packed_layout->rows; |
51 | packed_layout->kernel = kernel_layout; |
52 | } |
53 | |
54 | template <typename Scalar, typename PackedScalar> |
55 | void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout, |
56 | TrMulParams* params) { |
57 | // Ruy always uses 32-bit signed accumulators for quantized |
58 | // matrix multiplication, so we would like to always use std::int32_t |
59 | // unconditionally for SumsType. |
60 | // However, for floating point types, we still need a reasonable type here to |
61 | // avoid tripping assertions elsewhere in the code. |
62 | using SumsType = |
63 | typename std::conditional<std::is_floating_point<Scalar>::value, Scalar, |
64 | std::int32_t>::type; |
65 | |
66 | const EMat& src = params->src[side]; |
67 | PEMat* packed_matrix = ¶ms->packed_matrix[side]; |
68 | packed_matrix->data_type = Type::Create<PackedScalar>(); |
69 | packed_matrix->sums_type = Type::Create<SumsType>(); |
70 | CreatePackedLayout(src.layout, kernel_layout, &packed_matrix->layout); |
71 | packed_matrix->zero_point = Pack<PackedScalar, Scalar>(src.zero_point); |
72 | } |
73 | |
74 | template <typename KernelType> |
75 | struct CheckKernelPathImpl { |
76 | static void Run(Path) { |
77 | // Do nothing. |
78 | // Path fallbacks are normal in general (see RUY_INHERIT_KERNEL). |
79 | // That is to say that one may instantiate ruy::Mul with a weird combination |
80 | // of types, such as LhsScalar==float and RhsScalar==double, and have it |
81 | // work by silently falling back to Path::kStandardCpp. Only in specific |
82 | // cases do we have dedicated kernels overriding that fallback, and that is |
83 | // what partial specializations of this template will check. |
84 | } |
85 | }; |
86 | |
87 | #if RUY_DCHECK_IS_ENABLED |
88 | template <Path ThePath, typename SrcScalar, typename AccumScalar, |
89 | typename DstScalar> |
90 | struct CheckKernelPathImpl<Kernel<ThePath, SrcScalar, SrcScalar, DstScalar, |
91 | MulParams<AccumScalar, DstScalar>>> |
92 | final { |
93 | using KernelType = Kernel<ThePath, SrcScalar, SrcScalar, DstScalar, |
94 | MulParams<AccumScalar, DstScalar>>; |
95 | static void Run(Path expected_path) { |
96 | // We want to assert that we are using a dedicated Kernel specialization and |
97 | // not a fallback when we know we are in a case where such a kernel |
98 | // specialization exists. At the moment in the current state of ruy's |
99 | // architecture support for ARM and x86, that is when LhsScalar==RhsScalar |
100 | // (already implied in this partial specialization) and when that type is |
101 | // either float, int8, or uint8. Indeed, we have kernels supporting float |
102 | // and int8, and we have the packing code converting uint8 to int8 (see |
103 | // PackedTypeImpl). |
104 | static constexpr bool kSrcScalarTypeSupportsFastKernels = |
105 | std::is_same<SrcScalar, float>::value || |
106 | std::is_same<SrcScalar, std::int8_t>::value || |
107 | std::is_same<SrcScalar, std::uint8_t>::value; |
108 | if (kSrcScalarTypeSupportsFastKernels) { |
109 | RUY_DCHECK_EQ(expected_path, KernelType::kPath); |
110 | } |
111 | } |
112 | }; |
113 | #endif |
114 | |
115 | template <typename KernelType> |
116 | void CheckKernelPath(Path expected_path) { |
117 | CheckKernelPathImpl<KernelType>::Run(expected_path); |
118 | } |
119 | |
120 | template <Path ThePath, typename LhsScalar, typename RhsScalar, |
121 | typename AccumScalar, typename DstScalar> |
122 | void PopulateTrMulParams(TrMulParams* params) { |
123 | RUY_TRACE_SCOPE; |
124 | using PackedLhsScalar = PackedType<ThePath, LhsScalar>; |
125 | using PackedRhsScalar = PackedType<ThePath, RhsScalar>; |
126 | using Kernel = |
127 | Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, AccumScalar, DstScalar>; |
128 | using LhsKernelLayout = typename Kernel::LhsLayout; |
129 | using RhsKernelLayout = typename Kernel::RhsLayout; |
130 | |
131 | params->path = ThePath; |
132 | |
133 | CreatePackedMatrix<LhsScalar, PackedLhsScalar>( |
134 | Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params); |
135 | CreatePackedMatrix<RhsScalar, PackedRhsScalar>( |
136 | Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params); |
137 | params->run_pack[Side::kLhs] = |
138 | &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>; |
139 | params->run_pack[Side::kRhs] = |
140 | &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>; |
141 | params->run_kernel = &RunKernel<Kernel>::Run; |
142 | CheckKernelPath<Kernel>(ThePath); |
143 | RUY_TRACE_INFO(POPULATE_TRMUL_PARAMS); |
144 | } |
145 | |
146 | // PopulateTrMulParamsAllCompiledPaths calls into one of multiple |
147 | // instantiations of PopulateTrMulParams. For each bit that is set in |
148 | // CompiledPaths, it statically instantiates PopulateTrMulParams with a Path |
149 | // corresponding to that single bit. The call to PopulateTrMulParams is |
150 | // guarded by a runtime check that it is in fact the dynamically selected path. |
151 | // |
152 | // PopulateTrMulParamsAllCompiledPaths is implemented with template |
153 | // metaprogramming by mutual recursion between PathSearchCountdown and |
154 | // PathSearchCompiledPaths. |
155 | // |
156 | // PopulateTrMulParamsAllCompiledPaths is logically implementing the following |
157 | // computation: |
158 | // |
159 | // template <Path CompiledPaths> |
160 | // void PopulateTrMulParamsAllCompiledPaths(Path the_path, |
161 | // TrMulParams* params) { |
162 | // for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1] |
163 | // Path current_path = static_cast<Path>(1 << bit); |
164 | // if ((CompiledPaths & current_path) != Path::kNone) { // [2] |
165 | // if (current_path == the_path) { // [3] |
166 | // PopulateTrMulParams<current_path, ...>(the_path, params); |
167 | // return; |
168 | // } |
169 | // } |
170 | // } |
171 | // } |
172 | // |
173 | // |
174 | // |
175 | // [1] - Done by the main definition of PathSearchCountdown. The `bit--` is |
176 | // done in the recursion of PathSearchOnlyCompiledPaths. |
177 | // [2] - Done by PathSearchOnlyCompiledPaths's partial template |
178 | // specialization on InCompiledPaths. This is the check which necessitates |
179 | // doing the whole computation at C++ compile time. |
180 | // [3] - Done by the `if` in the main definition of |
181 | // PathSearchOnlyCompiledPaths. |
182 | // |
183 | // The template metaprogramming is necessary because: |
184 | // - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++ |
185 | // compile-time constant. |
186 | // - PopulateTrMulParamsAllCompiledPaths must not instantiate |
187 | // inner loops for paths that are not in CompiledPaths, since that can result in |
188 | // bogus instantiations which cause a compile time failure. |
// Forward declaration; PathSearchCountdown and PathSearchOnlyCompiledPaths
// are mutually recursive (see the explanatory comment above).
template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename AccumScalar, typename DstScalar>
struct PathSearchCountdown;
192 | |
193 | template <Path CompiledPaths, bool InCompiledPaths, int BitNumber, |
194 | typename LhsScalar, typename RhsScalar, typename AccumScalar, |
195 | typename DstScalar> |
196 | struct PathSearchOnlyCompiledPaths { |
197 | static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber); |
198 | static void Search(Path the_path, TrMulParams* params) { |
199 | if (kCurrentPath == the_path) { |
200 | PopulateTrMulParams<kCurrentPath, LhsScalar, RhsScalar, AccumScalar, |
201 | DstScalar>(params); |
202 | return; |
203 | } |
204 | PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar, |
205 | AccumScalar, DstScalar>::Search(the_path, params); |
206 | } |
207 | }; |
208 | |
209 | // Skip this iteration if CompiledPaths doesn't contain the specified path. |
210 | template <Path CompiledPaths, int BitNumber, typename LhsScalar, |
211 | typename RhsScalar, typename AccumScalar, typename DstScalar> |
212 | struct PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar, |
213 | RhsScalar, AccumScalar, DstScalar> { |
214 | static void Search(Path the_path, TrMulParams* params) { |
215 | PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar, |
216 | AccumScalar, DstScalar>::Search(the_path, params); |
217 | } |
218 | }; |
219 | |
220 | template <Path CompiledPaths, int BitNumber, typename LhsScalar, |
221 | typename RhsScalar, typename AccumScalar, typename DstScalar> |
222 | struct PathSearchCountdown { |
223 | static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber); |
224 | static void Search(Path the_path, TrMulParams* params) { |
225 | PathSearchOnlyCompiledPaths< |
226 | CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber, |
227 | LhsScalar, RhsScalar, AccumScalar, DstScalar>::Search(the_path, params); |
228 | } |
229 | }; |
230 | |
231 | // Termination of the countdown. If the counter reaches -1, then we haven't |
232 | // found the specified path. |
233 | template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
234 | typename AccumScalar, typename DstScalar> |
235 | struct PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, AccumScalar, |
236 | DstScalar> { |
237 | static void Search(Path, TrMulParams*) { RUY_DCHECK(false); } |
238 | }; |
239 | |
240 | template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
241 | typename AccumScalar, typename DstScalar> |
242 | void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) { |
243 | RUY_TRACE_SCOPE; |
244 | return PathSearchCountdown<CompiledPaths, 8 * sizeof(Path) - 1, LhsScalar, |
245 | RhsScalar, AccumScalar, |
246 | DstScalar>::Search(the_path, params); |
247 | } |
248 | |
249 | template <typename AccumScalar, typename DstScalar> |
250 | void ( |
251 | const MulParams<AccumScalar, DstScalar>& mul_params, int user_size, |
252 | int user_capacity) { |
253 | #if RUY_DCHECK_IS_ENABLED |
254 | if (mul_params.bias()) { |
255 | for (int i = user_size; i < user_capacity; i++) { |
256 | RUY_DCHECK_EQ(mul_params.bias()[i], 0); |
257 | } |
258 | } |
259 | if (mul_params.multiplier_fixedpoint_perchannel()) { |
260 | for (int i = user_size; i < user_capacity; i++) { |
261 | RUY_DCHECK_EQ(mul_params.multiplier_fixedpoint_perchannel()[i], 0); |
262 | } |
263 | } |
264 | if (mul_params.multiplier_exponent_perchannel()) { |
265 | for (int i = user_size; i < user_capacity; i++) { |
266 | RUY_DCHECK_EQ(mul_params.multiplier_exponent_perchannel()[i], 0); |
267 | } |
268 | } |
269 | #else |
270 | (void)mul_params; |
271 | (void)user_size; |
272 | (void)user_capacity; |
273 | #endif |
274 | } |
275 | |
// Reallocates the per-channel buffers into larger, zero-padded copies when the
// packed matrix is wider than the user-provided capacity. This primary
// template handles the quantized case (int32 accumulators, non-int32
// destination), where the per-channel multiplier buffers exist in addition to
// the bias buffer.
template <typename AccumScalar, typename DstScalar,
          bool HaveQuantizedMultipliers =
              std::is_same<AccumScalar, std::int32_t>::value &&
              !std::is_same<DstScalar, std::int32_t>::value>
struct EnsurePerChannelBuffersLargeEnoughImpl {
  static void Run(const TrMulParams& params, Allocator* allocator,
                  MulParams<AccumScalar, DstScalar>* mul_params) {
    // The side holding the channels depends on the channel dimension.
    const Side channel_side =
        mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
                                                                  : Side::kRhs;
    // Capacity the kernel will address (packed, rounded-up layout) ...
    const int required_capacity =
        params.packed_matrix[channel_side].layout.cols;
    // ... versus the number of channels and the capacity the user guaranteed.
    const int user_size = params.src[channel_side].layout.cols;
    const int user_capacity = round_up_pot(
        user_size, mul_params->perchannel_buffers_capacity_rounding());
    // We should have already checked earlier for the case where
    // user_capacity >= required_capacity.
    RUY_DCHECK_GT(required_capacity, user_capacity);
    // For each present buffer: allocate a larger copy, copy the user data,
    // and zero-fill the padding region.
    if (mul_params->bias()) {
      AccumScalar* new_data =
          allocator->Allocate<AccumScalar>(required_capacity);
      std::memcpy(new_data, mul_params->bias(),
                  user_size * sizeof(AccumScalar));
      std::memset(new_data + user_size, 0,
                  (required_capacity - user_size) * sizeof(AccumScalar));
      mul_params->set_bias(new_data);
    }
    if (mul_params->multiplier_fixedpoint_perchannel()) {
      AccumScalar* new_data =
          allocator->Allocate<AccumScalar>(required_capacity);
      std::memcpy(new_data, mul_params->multiplier_fixedpoint_perchannel(),
                  user_size * sizeof(AccumScalar));
      std::memset(new_data + user_size, 0,
                  (required_capacity - user_size) * sizeof(AccumScalar));
      mul_params->set_multiplier_fixedpoint_perchannel(new_data);
    }
    if (mul_params->multiplier_exponent_perchannel()) {
      int* new_data = allocator->Allocate<int>(required_capacity);
      std::memcpy(new_data, mul_params->multiplier_exponent_perchannel(),
                  user_size * sizeof(int));
      std::memset(new_data + user_size, 0,
                  (required_capacity - user_size) * sizeof(int));
      mul_params->set_multiplier_exponent_perchannel(new_data);
    }
  }
};
322 | |
// Specialization for the non-quantized-multipliers case: only the bias buffer
// may exist, so only it needs the reallocate-and-zero-pad treatment.
template <typename AccumScalar, typename DstScalar>
struct EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar, false> {
  static void Run(const TrMulParams& params, Allocator* allocator,
                  MulParams<AccumScalar, DstScalar>* mul_params) {
    // The side holding the channels depends on the channel dimension.
    const Side channel_side =
        mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs
                                                                  : Side::kRhs;
    // Capacity the kernel will address (packed, rounded-up layout) versus the
    // number of channels and the capacity the user guaranteed.
    const int required_capacity =
        params.packed_matrix[channel_side].layout.cols;
    const int user_size = params.src[channel_side].layout.cols;
    const int user_capacity = round_up_pot(
        user_size, mul_params->perchannel_buffers_capacity_rounding());
    // We should have already checked earlier for the case where
    // user_capacity >= required_capacity.
    RUY_DCHECK_GT(required_capacity, user_capacity);
    if (mul_params->bias()) {
      // Allocate a larger copy, copy the user data, zero-fill the padding.
      AccumScalar* new_data =
          allocator->Allocate<AccumScalar>(required_capacity);
      std::memcpy(new_data, mul_params->bias(),
                  user_size * sizeof(AccumScalar));
      std::memset(new_data + user_size, 0,
                  (required_capacity - user_size) * sizeof(AccumScalar));
      mul_params->set_bias(new_data);
    }
  }
};
349 | |
350 | template <typename AccumScalar, typename DstScalar> |
351 | void EnsurePerChannelBuffersLargeEnough( |
352 | const TrMulParams& params, Ctx* ctx, |
353 | MulParams<AccumScalar, DstScalar>* mul_params) { |
354 | // Early exit in the common case where the packed matrix size matches the |
355 | // number of channels (as opposed to having been rounded up to a slightly |
356 | // larger value). |
357 | const Side channel_side = |
358 | mul_params->channel_dimension() == ChannelDimension::kRow ? Side::kLhs |
359 | : Side::kRhs; |
360 | const int required_capacity = params.packed_matrix[channel_side].layout.cols; |
361 | const int user_size = params.src[channel_side].layout.cols; |
362 | const int user_capacity = round_up_pot( |
363 | user_size, mul_params->perchannel_buffers_capacity_rounding()); |
364 | AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized( |
365 | *mul_params, user_size, user_capacity); |
366 | if (required_capacity <= user_capacity) { |
367 | return; |
368 | } |
369 | ctx->set_performance_advisory( |
370 | PerformanceAdvisory::kReallocatedPerChannelBuffer); |
371 | EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar>::Run( |
372 | params, ctx->GetMainAllocator(), mul_params); |
373 | } |
374 | |
375 | // Ensures that `params->mul_params_bytes` contains MulParams data that's ready |
376 | // to be consumed by the kernel. As a first-order approximation, that is simply |
377 | // copying the user-provided `mul_params`, however there are a few changes. |
378 | // |
379 | // 1. The specified `channel_dimension` value overrides the channel_dimension |
380 | // member in `mul_params`. The reason why `channel_dimension` is being |
381 | // special-cased among MulParams members is that we will need to transpose |
382 | // MulParams, and that consists just in toggling channel_dimension. |
383 | // 2. Per-channel buffers may be reallocated, see |
384 | // EnsurePerChannelBuffersLargeEnough. |
385 | template <typename AccumScalar, typename DstScalar> |
386 | void FinalizeMulParams(const MulParams<AccumScalar, DstScalar>& mul_params, |
387 | ChannelDimension channel_dimension, Ctx* ctx, |
388 | TrMulParams* params) { |
389 | using MulParamsType = MulParams<AccumScalar, DstScalar>; |
390 | static_assert(alignof(MulParamsType) <= kMaxMulParamsAlignment, "" ); |
391 | static_assert(sizeof(MulParamsType) <= kMaxMulParamsSize, "" ); |
392 | static_assert(std::is_trivially_copyable<MulParamsType>::value, "" ); |
393 | auto* dst_mul_params = |
394 | reinterpret_cast<MulParamsType*>(params->mul_params_bytes); |
395 | std::memcpy(dst_mul_params, &mul_params, sizeof(MulParamsType)); |
396 | dst_mul_params->set_channel_dimension(channel_dimension); |
397 | EnsurePerChannelBuffersLargeEnough(*params, ctx, dst_mul_params); |
398 | } |
399 | |
400 | // In this function, the `channel_dimension` parameter overrides the value |
401 | // of the channel_dimension member in the `mul_params` parameter. See the |
402 | // FinalizeMulParams comment. |
// Builds a complete TrMulParams for a column-major destination: type-erases
// the matrices, selects the Path, populates the path-dependent fields, and
// finalizes the MulParams. Statement order matters here (see trailing
// comment), so the body is kept as-is.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void CreateTrMulParamsAssumingColMajorDst(
    const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
    const Mat<DstScalar>& dst,
    const MulParams<AccumScalar, DstScalar>& mul_params,
    ChannelDimension channel_dimension, Ctx* ctx, TrMulParams* params) {
  RUY_TRACE_SCOPE;
  RUY_DCHECK(IsColMajor(dst.layout));

  // Fill in the fields we already know.
  params->src[Side::kLhs] = EraseType(lhs);
  params->src[Side::kRhs] = EraseType(rhs);
  params->dst = EraseType(dst);

  // Determine which exact Path we're going to take in this Mul call.
  // This is cheap because it's cached in `ctx`. In user scenarios this always
  // evaluates to the same value on a given machine with given `CompiledPaths`,
  // but could be invalidated by a call to Ctx::SetRuntimeEnabledPaths(), which
  // might be exposed publicly in Context in the future.
  const Path the_path = ctx->SelectPath(CompiledPaths);

  RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST);

  // If we ever need again to fall back to Path::kStandardCpp, this is a good
  // place to do it -- just pass Path::kStandardCpp as both the template and
  // runtime parameters in this function call.
  // In the past we did that here (as version control history remembers).
  // A typical reason why we might need to resurrect that is if we implement
  // a new Path (i.e. port to a new ISA) and need to subdivide that work into
  // a series of incremental changes.
  PopulateTrMulParamsAllCompiledPaths<CompiledPaths, LhsScalar, RhsScalar,
                                      AccumScalar, DstScalar>(the_path, params);

  // This must be done last, as it depends on the specific choice of kernel.
  // Specifically, the EnsurePerChannelBuffersLargeEnough part of this will read
  // the packed matrix layouts that are written to `params` by the above
  // PopulateTrMulParams* call.
  FinalizeMulParams(mul_params, channel_dimension, ctx, params);
}
443 | |
444 | } // namespace detail |
445 | |
446 | inline ChannelDimension Transpose(ChannelDimension channel_dimension) { |
447 | return channel_dimension == ChannelDimension::kCol ? ChannelDimension::kRow |
448 | : ChannelDimension::kCol; |
449 | } |
450 | |
451 | // CreateTrMulParams's output is a TrMulParams object that encodes |
// all of the input information required by the middle-end, that is,
453 | // the TrMul function. |
454 | // |
455 | // CreateTrMulParams performs the following tasks: |
456 | // 1. Reduce to the case of column-major destination, by transposing the |
457 | // whole problem as needed. |
458 | // 2. Select the single code path to be taken, out of the set of paths |
459 | // described by the `CompiledPaths` template parameter, based on the |
460 | // runtime input parameter `the_path`. |
461 | // 3. Perform type-erasure, converting templatized typed input parameters |
462 | // to the un-typed data stored in TrMulParams. |
463 | template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
464 | typename AccumScalar, typename DstScalar> |
465 | void CreateTrMulParams(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, |
466 | const Mat<DstScalar>& dst, |
467 | const MulParams<AccumScalar, DstScalar>& mul_params, |
468 | Ctx* ctx, TrMulParams* params) { |
469 | RUY_TRACE_SCOPE; |
470 | ChannelDimension channel_dimension = mul_params.channel_dimension(); |
471 | if (IsColMajor(dst.layout)) { |
472 | detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>( |
473 | lhs, rhs, dst, mul_params, channel_dimension, ctx, params); |
474 | } else { |
475 | RUY_TRACE_INFO(CREATE_TRMUL_PARAMS_TRANSPOSING); |
476 | detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>( |
477 | rhs, lhs, Transpose(dst), mul_params, Transpose(channel_dimension), ctx, |
478 | params); |
479 | } |
480 | } |
481 | |
482 | } // namespace ruy |
483 | |
484 | #endif // RUY_RUY_CREATE_TRMUL_PARAMS_H_ |
485 | |