1 | // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // kernel_reference.h: a reference kernel for CPU architectures where we don't |
16 | // have optimized kernels yet. Also useful for testing, as it's templatized |
17 | // to have any arbitrary format, allowing tests to cover all sorts of corner |
18 | // cases. |
19 | |
20 | #ifndef GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ |
21 | #define GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ |
22 | |
23 | #include "kernel.h" |
24 | |
25 | #include <cstdio> |
26 | #include <cstring> |
27 | |
28 | namespace gemmlowp { |
29 | |
30 | // This kernel is templatized in an arbitrary Format template parameter, |
31 | // allowing it to have any arbitrary format. |
32 | template <typename tFormat> |
33 | struct ReferenceKernel : KernelBase { |
34 | typedef tFormat Format; |
35 | |
36 | const char* Name() const override { |
37 | static char buf[256]; |
38 | snprintf(buf, sizeof(buf), |
39 | "reference(Lhs: %d cells %dx%d %s, Rhs: %d cells %dx%d %s)" , |
40 | Format::Lhs::kCells, Format::Lhs::Cell::kWidth, |
41 | Format::Lhs::Cell::kDepth, |
42 | CellOrderName(Format::Lhs::Cell::kOrder), Format::Rhs::kCells, |
43 | Format::Rhs::Cell::kDepth, Format::Rhs::Cell::kWidth, |
44 | CellOrderName(Format::Rhs::Cell::kOrder)); |
45 | return buf; |
46 | } |
47 | |
48 | void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, |
49 | std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, |
50 | const std::uint8_t* rhs_ptr, std::size_t start_depth, |
51 | std::size_t run_depth) const override { |
52 | std::int32_t accumulator[Format::kRows * Format::kCols]; |
53 | memset(accumulator, 0, sizeof(accumulator)); |
54 | |
55 | const int run_depth_cells = static_cast<int>(run_depth / Format::kDepth); |
56 | |
57 | // The outer loop is over the depth dimension. |
58 | for (int dc = 0; dc < run_depth_cells; dc++) { |
59 | // The next two loops are over cells of the Lhs (stacked vertically), |
60 | // and over cells of the Rhs (stacked horizontally). |
61 | for (int rc = 0; rc < Format::Lhs::kCells; rc++) { |
62 | const std::uint8_t* lhs_cell_ptr = |
63 | lhs_ptr + (dc * Format::Lhs::kCells + rc) * |
64 | Format::Lhs::Cell::kWidth * Format::kDepth; |
65 | for (int cc = 0; cc < Format::Rhs::kCells; cc++) { |
66 | const std::uint8_t* rhs_cell_ptr = |
67 | rhs_ptr + (dc * Format::Rhs::kCells + cc) * |
68 | Format::Rhs::Cell::kWidth * Format::kDepth; |
69 | |
70 | // Now we are inside one cell of the Lhs and inside one cell |
71 | // of the Rhs, so the remaining inner loops are just |
72 | // traditional three loops of matrix multiplication. |
73 | for (int di = 0; di < Format::kDepth; di++) { |
74 | for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) { |
75 | for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) { |
76 | const std::uint8_t* lhs_coeff_ptr = |
77 | lhs_cell_ptr + |
78 | OffsetIntoCell<typename Format::Lhs::Cell>(ri, di); |
79 | const std::uint8_t* rhs_coeff_ptr = |
80 | rhs_cell_ptr + |
81 | OffsetIntoCell<typename Format::Rhs::Cell>(ci, di); |
82 | std::int32_t* accumulator_coeff_ptr = |
83 | accumulator + (ri + rc * Format::Lhs::Cell::kWidth) + |
84 | (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows; |
85 | *accumulator_coeff_ptr += |
86 | std::int32_t(*lhs_coeff_ptr) * std::int32_t(*rhs_coeff_ptr); |
87 | } |
88 | } |
89 | } |
90 | } |
91 | } |
92 | } |
93 | |
94 | if (start_depth == 0) { |
95 | // start_depth == 0 means we haven't accumulated anything yet, so we need |
96 | // to overwrite the accumulator, as it hasn't been initialized to zero. |
97 | for (int r = 0; r < Format::kRows; r++) { |
98 | for (int c = 0; c < Format::kCols; c++) { |
99 | dst_ptr[r * dst_row_stride + c * dst_col_stride] = |
100 | accumulator[r + c * Format::kRows]; |
101 | } |
102 | } |
103 | } else { |
104 | // We have already accumulated stuff, so we need to continue accumulating |
105 | // instead of just overwriting. |
106 | for (int r = 0; r < Format::kRows; r++) { |
107 | for (int c = 0; c < Format::kCols; c++) { |
108 | dst_ptr[r * dst_row_stride + c * dst_col_stride] += |
109 | accumulator[r + c * Format::kRows]; |
110 | } |
111 | } |
112 | } |
113 | } |
114 | }; |
115 | |
116 | } // namespace gemmlowp |
117 | |
118 | #endif // GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ |
119 | |