1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Implementation of MulFrontEnd, the front-end part of ruy. |
17 | // This is what the ruy::Mul entry point calls, and this ends in a call to |
18 | // TrMul, at which point we enter the middle-end. |
19 | // The front-end work includes parameter validation (Validate), detemplatization |
20 | // and resolution of the specific code path to take (CreateTrMulParams), and |
21 | // any additional logic best done upfront before entering the middle-end |
22 | // (e.g. HandlePrepackedCaching). |
23 | // The call to CreateTrMulParams is an important watershed in this code's |
24 | // structure: code before it needs to be templatized like the ruy::Mul entry |
25 | // point, code after it is un-templatized. |
26 | |
27 | #ifndef RUY_RUY_FRONTEND_H_ |
28 | #define RUY_RUY_FRONTEND_H_ |
29 | |
30 | #include "ruy/create_trmul_params.h" |
31 | #include "ruy/ctx.h" |
32 | #include "ruy/profiler/instrumentation.h" |
33 | #include "ruy/trace.h" |
34 | #include "ruy/trmul_params.h" |
35 | #include "ruy/validate.h" |
36 | |
37 | namespace ruy { |
38 | |
39 | // The first half of front-end work, up to the point where we have TrMulParams. |
40 | // In other words, this is the part of the front-end work that needs to be |
41 | // templatized like the entry point, and that performs the initial work that |
42 | // requires this templatization, and the de-templatization. The output of this |
43 | // function is the TrMulParams, which contain enough information to allow the |
44 | // un-templatized code to take over from there. |
45 | template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
46 | typename AccumScalar, typename DstScalar> |
47 | void MulFrontEndUpToCreateTrMulParams( |
48 | const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, |
49 | const Mat<DstScalar>& dst, |
50 | const MulParams<AccumScalar, DstScalar>& mul_params, Ctx* ctx, |
51 | TrMulParams* params) { |
52 | RUY_TRACE_SCOPE; |
53 | static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path" ); |
54 | static_assert( |
55 | (CompiledPaths & ~kAllPathsIncludingInternalVariants) == Path::kNone, |
56 | "CompiledPaths must be a subset of " |
57 | "ruy::kAllPathsIncludingInternalVariants" ); |
58 | |
59 | // Perform validation of parameters early so that failures are easier to map |
60 | // to user errors. In particular, perform this validation before the |
61 | // transposition. |
62 | Validate(lhs, rhs, dst); |
63 | |
64 | // De-templatize this Mul call by creating a TrMulParams structure. |
65 | // This is also where the specific kernel and pack code paths corresponding to |
66 | // `the_path` are selected, among all the code paths in `CompiledPaths`, and |
67 | // recorded as function pointers in the TrMulParams. |
68 | // The Transpose(lhs) here is where we switch from 'Mul' to 'TrMul'. |
69 | CreateTrMulParams<CompiledPaths>(Transpose(lhs), rhs, dst, mul_params, ctx, |
70 | params); |
71 | } |
72 | |
73 | // The second part of the front-end work, starting from where we have freshly |
74 | // created TrMulParams, performing any remaining front-end work and entering the |
75 | // middle-end. |
76 | void MulFrontEndFromTrMulParams(Ctx* ctx, TrMulParams* params); |
77 | |
78 | // Top-level function orchestrating the two halves of front-end work: |
79 | // before and after we have detemplatized the call by creating TrMulParams. |
80 | template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
81 | typename AccumScalar, typename DstScalar> |
82 | void MulFrontEnd(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, |
83 | const MulParams<AccumScalar, DstScalar>& mul_params, Ctx* ctx, |
84 | Mat<DstScalar>* dst) { |
85 | RUY_TRACE_SCOPE; |
86 | profiler::ScopeLabel mul_label("Mul" ); |
87 | profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d" , |
88 | lhs.layout.rows, lhs.layout.cols, |
89 | rhs.layout.cols); |
90 | ctx->clear_performance_advisories(); |
91 | TrMulParams params; |
92 | MulFrontEndUpToCreateTrMulParams<CompiledPaths>(lhs, rhs, *dst, mul_params, |
93 | ctx, ¶ms); |
94 | MulFrontEndFromTrMulParams(ctx, ¶ms); |
95 | } |
96 | |
97 | } // namespace ruy |
98 | |
99 | #endif // RUY_RUY_FRONTEND_H_ |
100 | |