1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef RUY_RUY_BLOCK_MAP_H_
17#define RUY_RUY_BLOCK_MAP_H_
18
19#include "ruy/cpu_cache_params.h"
20#include "ruy/side_pair.h"
21
22namespace ruy {
23
// Order in which to traverse the matrix tiled by a BlockMap
// (see BlockMap::traversal_order). Enumerator values are spelled out
// explicitly; they are the same values the compiler would assign implicitly.
enum class BlockMapTraversalOrder {
  // Row-by-row or column-by-column traversal, nothing fancy.
  kLinear = 0,
  // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve
  kFractalZ = 1,
  // Z-order variant whose elementary step traces a U rather than a Z.
  kFractalU = 2,
  // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve
  kFractalHilbert = 3,
};
34
35// A BlockMap describes a tiling of a matrix, typically the destination matrix
36// of a matrix multiplication computation. As is standard in matrix
37// multiplication, a tile is called a "block".
38//
39// Ruy subdivides work by blocks of the destination matrix: each thread fully
40// computes a block at once, then moves on to another block; each block is
41// produced by a single thread.
42//
43// This ensures that the workloads for each block are mutually independent,
44// which reduces synchronization requirements.
45//
46// Typically, a matrix multiplication will early on create a BlockMap by
47// calling MakeBlockMap. It will then query the number of blocks in that
48// BlockMap by calling NumBlocks. It will then create a single atomic integer
49// counter indexing these blocks, called the 'index', and will distribute
50// work to its N threads by ensuring that each thread works on disjoint sets
51// of index values. For a given index value, the thread will call
52// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords
53// to find the actual row and column numbers of this block.
54//
// There are two nested levels of subdivision. On a local level, the matrix is
// tiled into a square NxN grid where N is a power of two, specifically:
// N = 2^num_blocks_base_log2.
//
// At a larger scale, around these blocks, there may be one further
// level of subdivision, in only one dimension: either along rows or along
// columns. That is used to handle arbitrarily rectangular matrices: the
// square NxN grid described above does not, by itself, fit very rectangular
// matrices well.
64//
// Taken together, these two nested levels of subdivision give an effective
// tiling of
//   2^(num_blocks_base_log2 + rectangularness_log2[Side::kLhs])
// blocks in the row dimension, and of
//   2^(num_blocks_base_log2 + rectangularness_log2[Side::kRhs])
// blocks in the column dimension. See NumBlocksPerSide.
//
// At least one of rectangularness_log2[Side::kLhs] and
// rectangularness_log2[Side::kRhs] must be zero.
73//
// Finally, this BlockMap is designed to operate under alignment constraints:
// the kernel_dims field describes the requested alignment of the effective
// grid along rows (kernel_dims[Side::kLhs]) and along columns
// (kernel_dims[Side::kRhs]). The idea is to feed matrix multiplication
// kernels with tiles that fit their width as much as possible.
// Of course, if the number of rows (resp. columns) is not a multiple of
// kernel_dims[Side::kLhs] (resp. kernel_dims[Side::kRhs]), then some tile
// will have to have an unaligned size. BlockMap only allows that to happen in
// the last position along each axis, so as to minimize the overhead incurred
// by the matrix multiplication kernels.
struct BlockMap {
  // Plain aggregate: no member has a default initializer, so a
  // default-initialized BlockMap holds indeterminate values until it is
  // filled in (e.g. by MakeBlockMap, which writes through a BlockMap*).
  //
  // The SidePair<int> members are indexed by Side: [Side::kLhs] refers to the
  // rows axis and [Side::kRhs] to the columns axis, the same convention used
  // by NumBlocksPerSide and GetBlockMatrixCoords.

  // The number of threads to use (to distribute the blocks to).
  int thread_count;
  // The order in which to traverse the matrix of which this BlockMap represents
  // a tiling (hereafter "the matrix").
  BlockMapTraversalOrder traversal_order;
  // The dimensions of the block_map, that is, of the destination
  // matrix rounded up to next multiples of kernel_dims.
  SidePair<int> dims;
  // Log2 of the minimum number of subdivisions of the grid along either axis.
  // Every axis has at least 2^num_blocks_base_log2 blocks (see
  // NumBlocksPerSide).
  int num_blocks_base_log2;
  // Log2 of the additional subdivision of the rows/columns axis. At least one
  // of the two entries is expected to be zero.
  SidePair<int> rectangularness_log2;
  // Requested alignment of the subdivisions of the grid along the rows/columns
  // axis.
  SidePair<int> kernel_dims;
  // Internal helper. Minimum number of rows/columns in each block.
  SidePair<int> small_block_dims;
  // Internal helper. Number of blocks along each dimension that need to have
  // their size in that dimension be given by (small_block_dims + kernel_dims)
  // instead of just small_block_dims.
  SidePair<int> large_blocks;
};
105
// This function produces a coarse estimate of whether linear traversal will
// be used for this matmul. It offers a one-way guarantee: if this function
// returns true then linear traversal will be used. A false return carries no
// guarantee either way.
//
// The purpose of this function is to allow TrMul to make a cheap, early
// decision to enter a "simple loop" code path for simple cases.
//
// Parameters (presumably the same meaning as the like-named MakeBlockMap
// parameters; confirm in block_map.cc):
//   rows, cols, depth: matrix multiplication dimensions.
//   lhs_scalar_size, rhs_scalar_size: sizes of the LHS/RHS scalar types
//     (presumably in bytes).
//   cpu_cache_params: CPU cache parameters feeding the estimate.
bool IsObviouslyLinearTraversal(int rows, int cols, int depth,
                                int lhs_scalar_size, int rhs_scalar_size,
                                const CpuCacheParams& cpu_cache_params);
115
// Create a BlockMap suitable for tiling the destination matrix in a
// matrix multiplication with the given parameters.
//
//   rows, cols, depth: matrix multiplication dimensions; rows x cols is
//     presumably the destination matrix being tiled.
//   kernel_rows, kernel_cols: requested block alignment along rows/columns
//     (presumably stored into block_map->kernel_dims; TODO confirm).
//   lhs_scalar_size, rhs_scalar_size: sizes of the LHS/RHS scalar types
//     (presumably in bytes).
//   tentative_thread_count: desired number of threads. The name suggests the
//     resulting block_map->thread_count may differ; confirm in block_map.cc.
//   cpu_cache_params: CPU cache parameters used when choosing the tiling.
//   block_map: output; the BlockMap to fill in.
void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
                  int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
                  int tentative_thread_count,
                  const CpuCacheParams& cpu_cache_params, BlockMap* block_map);
122
// Maps an integer index to a block position in the grid.
//
//   index: must be in [0, NumBlocks(block_map)); see NumBlocks.
//   block: output; receives the block's position in the grid, indexed by Side
//     (presumably [Side::kLhs] = block row, [Side::kRhs] = block column). It
//     can be passed to GetBlockMatrixCoords to obtain matrix coordinates.
void GetBlockByIndex(const BlockMap& block_map, int index,
                     SidePair<int>* block);
126
// Given a block position in the grid, returns its actual
// position in the matrix that the BlockMap refers to in the dimension
// referred to by `side`: along rows if side==kLhs, along columns if
// side==kRhs.
//
//   block: the block's grid coordinate along `side`, presumably in
//     [0, NumBlocksPerSide(side, block_map)).
//   start, end: outputs; the block's row (or column) range in the matrix.
//     NOTE(review): presumably half-open [start, end); confirm in block_map.cc.
void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
                          int* start, int* end);
133
// Given a block position in the grid, returns its actual
// position in the matrix that the BlockMap refers to in terms of
// actual row/column indices.
//
// Two-dimensional counterpart of the per-Side overload: `block`, `start` and
// `end` are indexed by Side ([Side::kLhs] for rows, [Side::kRhs] for columns).
// NOTE(review): presumably equivalent to calling the per-Side overload once
// for each side; confirm in block_map.cc.
void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
                          SidePair<int>* start, SidePair<int>* end);
139
140// Returns the number of grid subdivisions along the rows dimension (if
141// side == kLhs) or columns dimension (if side == kRhs).
142inline int NumBlocksPerSide(Side side, const BlockMap& block_map) {
143 return 1 << (block_map.num_blocks_base_log2 +
144 block_map.rectangularness_log2[side]);
145}
146
147// Returns the overall number of blocks in
148// the BlockMap. The valid index values to pass to GetBlockByIndex are the
149// integers from 0 to N-1 where N is the value returned here.
150//
151// Note that it is always true that
152// NumBlocks == NumBlocksOfRows * NumBlocksOfCols
153// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0.
154inline int NumBlocks(const BlockMap& block_map) {
155 return 1 << (2 * block_map.num_blocks_base_log2 +
156 block_map.rectangularness_log2[Side::kLhs] +
157 block_map.rectangularness_log2[Side::kRhs]);
158}
159
160} // namespace ruy
161
162#endif // RUY_RUY_BLOCK_MAP_H_
163