1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef RUY_RUY_BLOCK_MAP_H_ |
17 | #define RUY_RUY_BLOCK_MAP_H_ |
18 | |
19 | #include "ruy/cpu_cache_params.h" |
20 | #include "ruy/side_pair.h" |
21 | |
22 | namespace ruy { |
23 | |
// Order in which the blocks of a BlockMap grid are visited.
enum class BlockMapTraversalOrder {
  // Straightforward scan, one row (or one column) of blocks after another.
  kLinear,
  // Recursive Z-order curve, see https://en.wikipedia.org/wiki/Z-order_curve
  kFractalZ,
  // Same recursive scheme as kFractalZ, but each step traces a U, not a Z.
  kFractalU,
  // Hilbert space-filling curve, see https://en.wikipedia.org/wiki/Hilbert_curve
  kFractalHilbert
};
34 | |
// A BlockMap describes a tiling of a matrix, typically the destination matrix
// of a matrix multiplication computation. As is standard in matrix
// multiplication, a tile is called a "block".
//
// Ruy subdivides work by blocks of the destination matrix: each thread fully
// computes a block at once, then moves on to another block; each block is
// produced by a single thread.
//
// This ensures that the workloads for each block are mutually independent,
// which reduces synchronization requirements.
//
// Typically, a matrix multiplication will early on create a BlockMap by
// calling MakeBlockMap. It will then query the number of blocks in that
// BlockMap by calling NumBlocks. It will then create a single atomic integer
// counter indexing these blocks, called the 'index', and will distribute
// work to its worker threads by ensuring that each thread works on disjoint
// sets of index values. For a given index value, the thread will call
// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords
// to find the actual row and column numbers of this block.
//
// There are two nested levels of subdivision. On a local level, the matrix is
// tiled into a square NxN grid where N is a power of two, specifically:
//   N = 2^num_blocks_base_log2.
//
// At a larger scale, around these blocks, there may be one further
// level of subdivision, in only one dimension: either along rows or along
// columns. That is used to handle arbitrarily rectangular matrices: the NxN
// grid described above is square, so on its own it does not fit very
// rectangular matrices well.
//
// Taken together, these two nested levels of subdivision yield an effective
// tiling of
//   2^(num_blocks_base_log2 + rectangularness_log2[Side::kLhs])
// blocks in the row dimension, and of
//   2^(num_blocks_base_log2 + rectangularness_log2[Side::kRhs])
// blocks in the column dimension. See NumBlocksPerSide.
//
// At least one of rectangularness_log2[Side::kLhs] and
// rectangularness_log2[Side::kRhs] must be zero.
//
// Finally, this BlockMap is designed to operate under alignment constraints:
// the kernel_dims field (kernel_dims[Side::kLhs] for rows,
// kernel_dims[Side::kRhs] for columns) describes the requested alignment of
// the effective grid in each dimension. The idea is to feed matrix
// multiplication kernels with tiles that fit their width as much as possible.
// Of course, if the number of rows (resp. cols) is not a multiple of the row
// (resp. column) alignment, then some tile will have to have an unaligned
// size. BlockMap only allows that to happen in the last position along each
// axis, so as to minimize the overhead incurred by the matrix multiplication
// kernels.
struct BlockMap {
  // The number of threads to use (i.e. to distribute the blocks to).
  int thread_count;
  // The order in which to traverse the matrix of which this BlockMap represents
  // a tiling (hereafter "the matrix").
  BlockMapTraversalOrder traversal_order;
  // The dimensions of the block_map, that is, of the destination matrix
  // rounded up to the next multiples of kernel_dims. Indexed by Side::kLhs
  // (rows) and Side::kRhs (columns).
  SidePair<int> dims;
  // Log2 of the minimum number of subdivisions of the grid along either axis
  // (the N of the square NxN grid is 2^num_blocks_base_log2).
  int num_blocks_base_log2;
  // Log2 of the additional subdivision of the rows (Side::kLhs) / columns
  // (Side::kRhs) axis. At least one of the two entries must be zero.
  SidePair<int> rectangularness_log2;
  // Requested alignment of the subdivisions of the grid along the rows/columns
  // axis.
  SidePair<int> kernel_dims;
  // Internal helper. Minimum number of rows/columns in each block.
  SidePair<int> small_block_dims;
  // Internal helper. Number of blocks along each dimension whose size in that
  // dimension is (small_block_dims + kernel_dims) rather than just
  // small_block_dims.
  SidePair<int> large_blocks;
};
105 | |
// This function produces a coarse, cheap estimate of whether linear traversal
// will be used for this matmul. It offers a one-way guarantee: if this
// function returns true, then linear traversal will be used. A false return
// carries no guarantee either way.
//
// The purpose of this function is to allow TrMul to make a cheap, early
// decision to enter a "simple loop" code path for simple cases.
//
// NOTE(review): the parameters presumably mirror the same-named arguments of
// MakeBlockMap (destination rows/cols, accumulation depth, LHS/RHS scalar
// sizes, CPU cache parameters); confirm against the definition.
bool IsObviouslyLinearTraversal(int rows, int cols, int depth,
                                int lhs_scalar_size, int rhs_scalar_size,
                                const CpuCacheParams& cpu_cache_params);
115 | |
// Creates a BlockMap suitable for tiling the destination matrix in a
// matrix multiplication with the given parameters, writing it to *block_map.
//
// kernel_rows / kernel_cols: the requested alignment of the grid along rows
//   and columns (see BlockMap::kernel_dims).
// tentative_thread_count: a thread-count hint. NOTE(review): "tentative"
//   suggests the value stored in block_map->thread_count may differ; confirm
//   in the definition.
// NOTE(review): rows/cols are presumably the destination matrix dimensions and
// depth the accumulation dimension; confirm against callers.
void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
                  int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
                  int tentative_thread_count,
                  const CpuCacheParams& cpu_cache_params, BlockMap* block_map);
122 | |
// Maps an integer index to a block position in the grid, written to *block.
// Valid indices are the integers from 0 to NumBlocks(block_map) - 1.
// NOTE(review): the index-to-position order is presumably determined by
// block_map.traversal_order; confirm in the definition.
void GetBlockByIndex(const BlockMap& block_map, int index,
                     SidePair<int>* block);
126 | |
// Given a block position in the grid along one dimension, returns its actual
// position in the matrix that the BlockMap refers to, in the dimension
// selected by `side`: along rows if side==kLhs, along columns if side==kRhs.
// The result is written to *start and *end.
// NOTE(review): presumably a half-open range [start, end); confirm in the
// definition.
void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
                          int* start, int* end);
133 | |
// Given a block position in the grid (both dimensions), returns its actual
// position in the matrix that the BlockMap refers to in terms of actual
// row/column indices: the Side::kLhs entries of *start/*end are rows, the
// Side::kRhs entries are columns.
// NOTE(review): presumably half-open ranges [start, end); confirm in the
// definition.
void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
                          SidePair<int>* start, SidePair<int>* end);
139 | |
140 | // Returns the number of grid subdivisions along the rows dimension (if |
141 | // side == kLhs) or columns dimension (if side == kRhs). |
142 | inline int NumBlocksPerSide(Side side, const BlockMap& block_map) { |
143 | return 1 << (block_map.num_blocks_base_log2 + |
144 | block_map.rectangularness_log2[side]); |
145 | } |
146 | |
147 | // Returns the overall number of blocks in |
148 | // the BlockMap. The valid index values to pass to GetBlockByIndex are the |
149 | // integers from 0 to N-1 where N is the value returned here. |
150 | // |
151 | // Note that it is always true that |
152 | // NumBlocks == NumBlocksOfRows * NumBlocksOfCols |
153 | // because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0. |
154 | inline int NumBlocks(const BlockMap& block_map) { |
155 | return 1 << (2 * block_map.num_blocks_base_log2 + |
156 | block_map.rectangularness_log2[Side::kLhs] + |
157 | block_map.rectangularness_log2[Side::kRhs]); |
158 | } |
159 | |
160 | } // namespace ruy |
161 | |
162 | #endif // RUY_RUY_BLOCK_MAP_H_ |
163 | |