1 | // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // unpack.h: unpacking the result blocks computed by compute.h, |
16 | // storing them into the destination matrix. |
17 | |
18 | #ifndef GEMMLOWP_INTERNAL_UNPACK_H_ |
19 | #define GEMMLOWP_INTERNAL_UNPACK_H_ |
20 | |
21 | #include "allocator.h" |
22 | #include "block_params.h" |
23 | #include "output.h" |
24 | #include "pack.h" |
25 | |
26 | #include <cmath> |
27 | |
28 | namespace gemmlowp { |
29 | |
30 | class PackedResult { |
31 | public: |
32 | PackedResult(Allocator* _allocator, const BlockParams& _block_params) |
33 | : allocator_(_allocator), block_params_(_block_params) { |
34 | matrix_handle_ = allocator_->Reserve<std::int32_t>(block_params_.l2_rows * |
35 | block_params_.l2_cols); |
36 | } |
37 | |
38 | ~PackedResult() {} |
39 | |
40 | MatrixMap<std::int32_t, MapOrder::ColMajor> Map() { |
41 | return MatrixMap<std::int32_t, MapOrder::ColMajor>( |
42 | allocator_->GetPointer<std::int32_t>(matrix_handle_), |
43 | block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows); |
44 | } |
45 | |
46 | MatrixMap<const std::int32_t, MapOrder::ColMajor> Map() const { |
47 | return MatrixMap<const std::int32_t, MapOrder::ColMajor>( |
48 | allocator_->GetPointer<const std::int32_t>(matrix_handle_), |
49 | block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows); |
50 | } |
51 | |
52 | private: |
53 | Allocator* allocator_; |
54 | Allocator::Handle matrix_handle_; |
55 | const BlockParams& block_params_; |
56 | }; |
57 | |
// Describes a rectangular sub-block of a matrix: its top-left corner
// (start_row, start_col) and its extent (rows x cols).
struct MatrixBlockBounds {
  int start_row;  // First row of the block.
  int start_col;  // First column of the block.
  int rows;       // Number of rows in the block.
  int cols;       // Number of columns in the block.

  MatrixBlockBounds(int first_row, int first_col, int num_rows, int num_cols)
      : start_row(first_row),
        start_col(first_col),
        rows(num_rows),
        cols(num_cols) {}
};
70 | |
71 | template <int Rows, int Cols, typename SrcMapType> |
72 | void PrefetchResultBlock(const SrcMapType& src, |
73 | const VectorMap<const std::int32_t, VectorShape::Col>& |
74 | lhs_sums_of_each_slice, |
75 | int src_row, int src_col) { |
76 | const std::int32_t* src_data = src.data(src_row, src_col); |
77 | const int src_stride = src.stride(); |
78 | const std::int32_t* lhs_sums_data = lhs_sums_of_each_slice.data(src_row); |
79 | for (int r = 0; r < Rows; r += 4) { |
80 | Prefetch(lhs_sums_data + r); |
81 | } |
82 | for (int c = 0; c < Cols; c++) { |
83 | for (int r = 0; r < Rows; r += 4) { |
84 | Prefetch(src_data + r + c * src_stride); |
85 | } |
86 | } |
87 | } |
88 | |
// Unpacks one register-sized block of the packed result at
// (src_row, src_col): loads the raw int32 accumulators, adds the
// offset-correction terms implied by lhs_offset / rhs_offset (shifted by
// the kernel format's zero-point input constants), then hands the
// corrected block to the output pipeline `executor`, which writes to
// `dst` at (dst_row, dst_col).
//
// With acc = the raw accumulators and lhs_offset' / rhs_offset' denoting
// the given offsets plus the kernel zero-point constants, the value
// passed to the executor is
//   acc + lhs_sums * rhs_offset'
//       + (rhs_sums + depth * rhs_offset') * lhs_offset'
// where `depth` is the accumulation depth.
template <typename KernelFormat, typename RegisterBlockType,
          typename SrcMapType, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineExecutorType, typename DstType>
void UnpackResultBlock(const SrcMapType& src,
                       const OutputPipelineExecutorType& executor, DstType* dst,
                       const VectorMap<const std::int32_t, VectorShape::Col>&
                           lhs_sums_of_each_slice,
                       const VectorMap<const std::int32_t, VectorShape::Row>&
                           rhs_sums_of_each_slice,
                       const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                       int depth, int src_row, int src_col, int src_global_row,
                       int src_global_col, int dst_row, int dst_col) {
  using KernelLhsInputScalar = typename KernelFormat::Lhs::InputScalar;
  using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
  using KernelRhsInputScalar = typename KernelFormat::Rhs::InputScalar;
  using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
  // Zero-point input constants for this kernel format (see
  // ZeroPointInputValue); folded into the offsets below.
  static constexpr int KernelLhsZeroPointInput =
      ZeroPointInputValue<KernelLhsInputScalar, KernelLhsScalar>::kValue;
  static constexpr int KernelRhsZeroPointInput =
      ZeroPointInputValue<KernelRhsInputScalar, KernelRhsScalar>::kValue;
  // Raw accumulators for this register block.
  auto acc = Load<RegisterBlockType>(src, src_row, src_col);
  // Per-row / per-column sums and offsets, broadcast to the block's shape.
  const auto& lhs_sums_of_each_slice_block =
      LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
  const auto& rhs_sums_of_each_slice_block =
      LoadForBroadcasting<RegisterBlockType>(rhs_sums_of_each_slice, src_col);
  auto lhs_offset_block =
      LoadForBroadcasting<RegisterBlockType>(lhs_offset, src_row);
  auto rhs_offset_block =
      LoadForBroadcasting<RegisterBlockType>(rhs_offset, src_col);
  // lhs_offset' = lhs_offset + KernelLhsZeroPointInput (and likewise rhs).
  AddConstant<KernelLhsZeroPointInput>(&lhs_offset_block);
  AddConstant<KernelRhsZeroPointInput>(&rhs_offset_block);
  // acc += lhs_sums * rhs_offset'.
  BroadcastMulAdd(lhs_sums_of_each_slice_block, rhs_offset_block, &acc);
  // Scale rhs_offset' by depth in place; it is only used once more below.
  for (int i = 0; i < decltype(rhs_offset_block)::kRegisterCount; i++) {
    rhs_offset_block.buf.reg[i] = Mul(rhs_offset_block.buf.reg[i], depth);
  }
  // acc += (rhs_sums + depth * rhs_offset') * lhs_offset'.
  BroadcastMulAdd(BroadcastAdd(rhs_sums_of_each_slice_block, rhs_offset_block),
                  lhs_offset_block, &acc);
  // Apply the output pipeline and store to dst.
  executor.Execute(acc, dst, src_global_row, src_global_col, dst_row, dst_col);
}
128 | |
// Unpacks the whole packed result `src` into the rectangle of `dst`
// described by `dst_block`, applying the offset corrections and the
// output pipeline block by block.
//
// Arguments:
//   dst: the destination matrix.
//   dst_block: the rectangle of dst to be written.
//   src: the packed int32 accumulators.
//   depth: the accumulation depth (passed through to UnpackResultBlock).
//   lhs_sums_of_each_slice_ptr: per-row sums of the packed lhs
//     (dst_block.rows entries).
//   rhs_sums_of_each_slice_ptr: per-column sums of the packed rhs
//     (dst_block.cols entries).
//   lhs_offset, rhs_offset: the offsets to fold into the correction.
//   output_pipeline: the output stages applied to each block.
//
// The loops below walk the destination in column groups of 8, then 4,
// then 1, and within each group in row steps of 8, then 4, then 1, so
// every (row, col) in dst_block is covered exactly once.
template <typename KernelFormat, typename ResultBlockType,
          typename PackedResultType, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType>
void UnpackResult(ResultBlockType* dst, const MatrixBlockBounds& dst_block,
                  const PackedResultType& src, int depth,
                  const std::int32_t* lhs_sums_of_each_slice_ptr,
                  const std::int32_t* rhs_sums_of_each_slice_ptr,
                  const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                  const OutputPipelineType& output_pipeline) {
  ScopedProfilingLabel label(ResultBlockType::kOrder == MapOrder::ColMajor
                                 ? "unpack to column-major"
                                 : "unpack to row-major");
  // The destination rectangle must lie inside dst.
  assert(dst_block.start_row >= 0);
  assert(dst_block.start_row + dst_block.rows <= dst->rows());
  assert(dst_block.start_col >= 0);
  assert(dst_block.start_col + dst_block.cols <= dst->cols());
  const auto src_map = src.Map();
  // Vector views over the per-row (lhs) and per-column (rhs) sums.
  const VectorMap<const std::int32_t, VectorShape::Col> lhs_sums_of_each_slice(
      lhs_sums_of_each_slice_ptr, dst_block.rows);
  const VectorMap<const std::int32_t, VectorShape::Row> rhs_sums_of_each_slice(
      rhs_sums_of_each_slice_ptr, dst_block.cols);
  // Register-block shapes (rows x cols) used by the tiled loops below.
  using Int32x1x1 = RegisterBlock<std::int32_t, 1, 1>;
  using Int32x4x1 = RegisterBlock<std::int32_t, 4, 1>;
  using Int32x8x1 = RegisterBlock<std::int32_t, 8, 1>;
  using Int32x1x4 = RegisterBlock<std::int32_t, 1, 4>;
  using Int32x4x4 = RegisterBlock<std::int32_t, 4, 4>;
  using Int32x8x4 = RegisterBlock<std::int32_t, 8, 4>;

  using DstScalarType = typename ResultBlockType::Scalar;
  using DstScalarx8x8 = RegisterBlock<DstScalarType, 8, 8>;

  // One output-pipeline executor per register-block shape.
  OutputPipelineExecutor<OutputPipelineType, Int32x1x1>
      output_pipeline_executor_1x1(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x4x1>
      output_pipeline_executor_4x1(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x8x1>
      output_pipeline_executor_8x1(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x1x4>
      output_pipeline_executor_1x4(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x4x4>
      output_pipeline_executor_4x4(output_pipeline);
  OutputPipelineExecutor<OutputPipelineType, Int32x8x4>
      output_pipeline_executor_8x4(output_pipeline);

  int c8 = 0;
  // Row-major destinations get a dedicated 8-columns-at-a-time path:
  // each 8x8 tile is unpacked into a small column-major staging buffer,
  // then stored to dst in one StoreFinalOutput call.
  if (ResultBlockType::kOrder == MapOrder::RowMajor) {
    for (; c8 <= dst_block.cols - 8; c8 += 8) {
      PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, 0, c8);
      int r = 0;
      for (; r <= dst_block.rows - 8; r += 8) {
        const int global_row = r + dst_block.start_row;
        // Prefetch the next 8x8 tile down the column while unpacking
        // this one.
        PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, r + 8, c8);
        // 8x8 = 64-entry column-major staging buffer for this tile.
        DstScalarType dst_colmajor_buf[64];
        MatrixMap<DstScalarType, MapOrder::ColMajor> dst_colmajor_map(
            dst_colmajor_buf, 8, 8);
        // Unpack the tile as two 8x4 halves into the staging buffer.
        for (int cx = 0; cx < 8; cx += 4) {
          const int c = c8 + cx;
          const int global_col = c + dst_block.start_col;
          UnpackResultBlock<KernelFormat, Int32x8x4>(
              src_map, output_pipeline_executor_8x4, &dst_colmajor_map,
              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
              rhs_offset, depth, r, c, global_row, global_col, 0, cx);
        }
        // Store the whole staged 8x8 tile to dst.
        StoreFinalOutput(LoadContiguous<DstScalarx8x8>(dst_colmajor_buf), dst,
                         r + dst_block.start_row, c8 + dst_block.start_col);
      }
      // Leftover rows (< 8) of this 8-column group: 4-row blocks, ...
      for (; r <= dst_block.rows - 4; r += 4) {
        const int global_row = r + dst_block.start_row;
        for (int cx = 0; cx < 8; cx += 4) {
          const int c = c8 + cx;
          const int global_col = c + dst_block.start_col;
          UnpackResultBlock<KernelFormat, Int32x4x4>(
              src_map, output_pipeline_executor_4x4, dst,
              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
              rhs_offset, depth, r, c, global_row, global_col, global_row,
              global_col);
        }
      }
      // ... then single rows.
      for (; r < dst_block.rows; r++) {
        const int global_row = r + dst_block.start_row;
        for (int cx = 0; cx < 8; cx += 4) {
          const int c = c8 + cx;
          const int global_col = c + dst_block.start_col;
          UnpackResultBlock<KernelFormat, Int32x1x4>(
              src_map, output_pipeline_executor_1x4, dst,
              lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
              rhs_offset, depth, r, c, global_row, global_col, global_row,
              global_col);
        }
      }
    }
  }
  // Remaining columns (all of them, for column-major destinations):
  // 4-column groups first.
  int c = c8;
  for (; c <= dst_block.cols - 4; c += 4) {
    const int global_col = c + dst_block.start_col;
    PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, 0, c);
    int r = 0;
    for (; r <= dst_block.rows - 8; r += 8) {
      const int global_row = r + dst_block.start_row;
      PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, r + 8, c);
      UnpackResultBlock<KernelFormat, Int32x8x4>(
          src_map, output_pipeline_executor_8x4, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
    for (; r <= dst_block.rows - 4; r += 4) {
      const int global_row = r + dst_block.start_row;
      UnpackResultBlock<KernelFormat, Int32x4x4>(
          src_map, output_pipeline_executor_4x4, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
    for (; r < dst_block.rows; r++) {
      const int global_row = r + dst_block.start_row;
      UnpackResultBlock<KernelFormat, Int32x1x4>(
          src_map, output_pipeline_executor_1x4, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
  }
  // Leftover single columns (< 4 remaining).
  for (; c < dst_block.cols; c++) {
    const int global_col = c + dst_block.start_col;
    PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, 0, c);
    int r = 0;
    for (; r <= dst_block.rows - 8; r += 8) {
      const int global_row = r + dst_block.start_row;
      PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, r + 8, c);
      UnpackResultBlock<KernelFormat, Int32x8x1>(
          src_map, output_pipeline_executor_8x1, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
    for (; r <= dst_block.rows - 4; r += 4) {
      const int global_row = r + dst_block.start_row;
      UnpackResultBlock<KernelFormat, Int32x4x1>(
          src_map, output_pipeline_executor_4x1, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
    for (; r < dst_block.rows; r++) {
      const int global_row = r + dst_block.start_row;
      UnpackResultBlock<KernelFormat, Int32x1x1>(
          src_map, output_pipeline_executor_1x1, dst, lhs_sums_of_each_slice,
          rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
          global_row, global_col, global_row, global_col);
    }
  }
}
277 | |
278 | } // end namespace gemmlowp |
279 | |
280 | #endif // GEMMLOWP_INTERNAL_UNPACK_H_ |
281 | |