1 | // Copyright 2017 The Gemmlowp Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // simd_wrappers.h: some inline functions wrapping SIMD intrinsics, |
16 | // extending the set of such functions from fixedpoint.h. |
17 | |
18 | #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ |
19 | #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ |
20 | |
#include <algorithm>
#include <cstdint>
#include <type_traits>
#include "../fixedpoint/fixedpoint.h"
24 | |
25 | namespace gemmlowp { |
26 | |
// Maps a (scalar type, scalar count) pair to the type used to hold those
// scalars in one "register". This primary template ignores the count and
// simply uses the scalar type itself.
template <typename Scalar, int Count>
struct RegisterType {
  using Type = Scalar;
};
31 | |
// Returns the smaller of the two given 32-bit signed integers.
inline std::int32_t Min(std::int32_t a, std::int32_t b) {
  return (b < a) ? b : a;
}
35 | |
// Returns the larger of the two given 32-bit signed integers.
inline std::int32_t Max(std::int32_t a, std::int32_t b) {
  return (a < b) ? b : a;
}
39 | |
// Scalar multiply-accumulate: adds lhs * rhs into *acc.
//
// The product and the sum are computed on unsigned 32-bit values, so a result
// that does not fit in std::int32_t wraps modulo 2^32 instead of triggering
// signed-overflow undefined behavior. Whenever the plain signed expression
// `*acc += lhs * rhs` would not overflow, the result is identical to it.
//
// acc must point to a valid accumulator; it is read and then overwritten.
inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
  const std::uint32_t product =
      static_cast<std::uint32_t>(lhs) * static_cast<std::uint32_t>(rhs);
  const std::uint32_t sum = static_cast<std::uint32_t>(*acc) + product;
  // Converting an out-of-range unsigned value back to int32 is
  // implementation-defined before C++20; mainstream compilers wrap.
  *acc = static_cast<std::int32_t>(sum);
}
43 | |
44 | template <typename tScalarType, int tScalarCount> |
45 | struct RegisterBuffer { |
46 | using ScalarType = tScalarType; |
47 | static constexpr int kScalarCount = tScalarCount; |
48 | using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type; |
49 | static_assert((kScalarCount & (kScalarCount - 1)) == 0, |
50 | "kScalarCount must be a power of two" ); |
51 | static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "" ); |
52 | static constexpr int kRegisterLanes = |
53 | sizeof(RegisterType) / sizeof(ScalarType); |
54 | static constexpr int kRegisterCount = |
55 | (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) / |
56 | sizeof(RegisterType); |
57 | |
58 | RegisterType reg[kRegisterCount]; |
59 | }; |
60 | |
61 | template <typename tScalarType, int tRows, int tCols> |
62 | struct RegisterBlock { |
63 | using ScalarType = tScalarType; |
64 | static constexpr int kRows = tRows; |
65 | static constexpr int kCols = tCols; |
66 | static constexpr int kScalarCount = kRows * kCols; |
67 | using BufferType = RegisterBuffer<ScalarType, kScalarCount>; |
68 | using RegisterType = typename BufferType::RegisterType; |
69 | static constexpr int kRegisterCount = BufferType::kRegisterCount; |
70 | static constexpr int kRegisterLanes = BufferType::kRegisterLanes; |
71 | |
72 | BufferType buf; |
73 | }; |
74 | |
// Adds two register blocks of the same type, one register at a time, by
// calling Add() on each pair of corresponding registers.
template <typename RegisterBlockType>
struct RegisterBlockAddImpl {
  static RegisterBlockType Run(const RegisterBlockType& lhs,
                               const RegisterBlockType& rhs) {
    constexpr int kCount = RegisterBlockType::kRegisterCount;
    RegisterBlockType sum;
    for (int reg_index = 0; reg_index != kCount; ++reg_index) {
      const auto& a = lhs.buf.reg[reg_index];
      const auto& b = rhs.buf.reg[reg_index];
      sum.buf.reg[reg_index] = Add(a, b);
    }
    return sum;
  }
};
86 | |
87 | template <typename RegisterBlockType> |
88 | RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs, |
89 | const RegisterBlockType& rhs) { |
90 | return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs); |
91 | } |
92 | |
// Compile-time predicate telling whether the two operand types of a
// broadcasting binary op should be swapped. kValue is true when LhsType holds
// fewer scalars than RhsType, or when both hold the same number of scalars
// and LhsType has fewer rows.
template <typename LhsType, typename RhsType>
struct ShouldFlipLhsRhs {
  static constexpr bool kValue =
      (LhsType::kScalarCount != RhsType::kScalarCount)
          ? (LhsType::kScalarCount < RhsType::kScalarCount)
          : (LhsType::kRows < RhsType::kRows);
};
100 | |
101 | template <typename LhsType, typename RhsType, |
102 | bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue> |
103 | struct FlipLhsRhs { |
104 | using FlippedLhsType = LhsType; |
105 | using FlippedRhsType = RhsType; |
106 | static const FlippedLhsType& FlippedLhs(const LhsType& lhs, |
107 | const RhsType& rhs) { |
108 | (void)rhs; |
109 | return lhs; |
110 | } |
111 | static const FlippedRhsType& FlippedRhs(const LhsType& lhs, |
112 | const RhsType& rhs) { |
113 | (void)lhs; |
114 | return rhs; |
115 | } |
116 | }; |
117 | |
118 | template <typename LhsType, typename RhsType> |
119 | struct FlipLhsRhs<LhsType, RhsType, true> { |
120 | using FlippedLhsType = RhsType; |
121 | using FlippedRhsType = LhsType; |
122 | static const FlippedLhsType& FlippedLhs(const LhsType& lhs, |
123 | const RhsType& rhs) { |
124 | (void)lhs; |
125 | return rhs; |
126 | } |
127 | static const FlippedRhsType& FlippedRhs(const LhsType& lhs, |
128 | const RhsType& rhs) { |
129 | (void)rhs; |
130 | return lhs; |
131 | } |
132 | }; |
133 | |
// Shape of the result of a broadcasting binary op: in each dimension, the
// larger of the two operands' extents.
template <typename Lhs, typename Rhs>
struct BroadcastBinaryOpShape {
  static constexpr int kRows =
      (Lhs::kRows < Rhs::kRows) ? Rhs::kRows : Lhs::kRows;
  static constexpr int kCols =
      (Lhs::kCols < Rhs::kCols) ? Rhs::kCols : Lhs::kCols;
};
141 | |
142 | template <typename Lhs, typename Rhs> |
143 | struct BroadcastBinaryOpRegisterBlock { |
144 | using Shape = BroadcastBinaryOpShape<Lhs, Rhs>; |
145 | using ScalarType = typename Lhs::ScalarType; |
146 | using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; |
147 | }; |
148 | |
149 | template <typename Lhs, typename Rhs> |
150 | struct BroadcastAddImpl { |
151 | using ResultBlockType = |
152 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; |
153 | static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { |
154 | ResultBlockType result; |
155 | static constexpr int Rows = ResultBlockType::kRows; |
156 | static constexpr int Cols = ResultBlockType::kCols; |
157 | static constexpr int LhsRows = Lhs::kRows; |
158 | static constexpr int LhsCols = Lhs::kCols; |
159 | static constexpr int RhsRows = Rhs::kRows; |
160 | static constexpr int RhsCols = Rhs::kCols; |
161 | |
162 | static_assert(LhsRows == Rows || LhsRows == 1, "" ); |
163 | static_assert(RhsRows == Rows || RhsRows == 1, "" ); |
164 | static_assert(LhsCols == Cols || LhsCols == 1, "" ); |
165 | static_assert(RhsCols == Cols || RhsCols == 1, "" ); |
166 | static_assert(ResultBlockType::kRegisterLanes == 1, |
167 | "This path is only for scalar values" ); |
168 | static_assert(Lhs::kRegisterLanes == 1, |
169 | "This path is only for scalar values" ); |
170 | static_assert(Rhs::kRegisterLanes == 1, |
171 | "This path is only for scalar values" ); |
172 | |
173 | for (int c = 0; c < Cols; c++) { |
174 | const int lhs_c = LhsCols == Cols ? c : 0; |
175 | const int rhs_c = RhsCols == Cols ? c : 0; |
176 | for (int r = 0; r < Rows; r++) { |
177 | const int lhs_r = LhsRows == Rows ? r : 0; |
178 | const int rhs_r = RhsRows == Rows ? r : 0; |
179 | result.buf.reg[r + c * Rows] = |
180 | Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows], |
181 | rhs.buf.reg[rhs_r + rhs_c * RhsRows]); |
182 | } |
183 | } |
184 | return result; |
185 | } |
186 | }; |
187 | |
188 | template <typename Lhs, typename Rhs> |
189 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd( |
190 | const Lhs& lhs, const Rhs& rhs) { |
191 | using Flip = FlipLhsRhs<Lhs, Rhs>; |
192 | return BroadcastAddImpl< |
193 | typename Flip::FlippedLhsType, |
194 | typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), |
195 | Flip::FlippedRhs(lhs, rhs)); |
196 | } |
197 | |
198 | template <typename Lhs, typename Rhs> |
199 | struct BroadcastShiftLeftImpl { |
200 | using ResultBlockType = |
201 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; |
202 | static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { |
203 | ResultBlockType result; |
204 | static constexpr int Rows = ResultBlockType::kRows; |
205 | static constexpr int Cols = ResultBlockType::kCols; |
206 | static constexpr int LhsRows = Lhs::kRows; |
207 | static constexpr int LhsCols = Lhs::kCols; |
208 | static constexpr int RhsRows = Rhs::kRows; |
209 | static constexpr int RhsCols = Rhs::kCols; |
210 | |
211 | static_assert(LhsRows == Rows || LhsRows == 1, "" ); |
212 | static_assert(RhsRows == Rows || RhsRows == 1, "" ); |
213 | static_assert(LhsCols == Cols || LhsCols == 1, "" ); |
214 | static_assert(RhsCols == Cols || RhsCols == 1, "" ); |
215 | static_assert(ResultBlockType::kRegisterLanes == 1, |
216 | "This path is only for scalar values" ); |
217 | static_assert(Lhs::kRegisterLanes == 1, |
218 | "This path is only for scalar values" ); |
219 | static_assert(Rhs::kRegisterLanes == 1, |
220 | "This path is only for scalar values" ); |
221 | |
222 | for (int c = 0; c < Cols; c++) { |
223 | const int lhs_c = LhsCols == Cols ? c : 0; |
224 | const int rhs_c = RhsCols == Cols ? c : 0; |
225 | for (int r = 0; r < Rows; r++) { |
226 | const int lhs_r = LhsRows == Rows ? r : 0; |
227 | const int rhs_r = RhsRows == Rows ? r : 0; |
228 | result.buf.reg[r + c * Rows] = |
229 | ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows], |
230 | rhs.buf.reg[rhs_r + rhs_c * RhsRows]); |
231 | } |
232 | } |
233 | return result; |
234 | } |
235 | }; |
236 | |
237 | template <typename Lhs, typename Rhs> |
238 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft( |
239 | const Lhs& lhs, const Rhs& rhs) { |
240 | using Flip = FlipLhsRhs<Lhs, Rhs>; |
241 | return BroadcastShiftLeftImpl< |
242 | typename Flip::FlippedLhsType, |
243 | typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), |
244 | Flip::FlippedRhs(lhs, rhs)); |
245 | } |
246 | |
247 | template <typename Lhs, typename Rhs> |
248 | struct BroadcastSaturatingRoundingDoublingHighMulImpl { |
249 | using ResultBlockType = |
250 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; |
251 | static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { |
252 | ResultBlockType result; |
253 | static constexpr int Rows = ResultBlockType::kRows; |
254 | static constexpr int Cols = ResultBlockType::kCols; |
255 | static constexpr int LhsRows = Lhs::kRows; |
256 | static constexpr int LhsCols = Lhs::kCols; |
257 | static constexpr int RhsRows = Rhs::kRows; |
258 | static constexpr int RhsCols = Rhs::kCols; |
259 | |
260 | static_assert(LhsRows == Rows || LhsRows == 1, "" ); |
261 | static_assert(RhsRows == Rows || RhsRows == 1, "" ); |
262 | static_assert(LhsCols == Cols || LhsCols == 1, "" ); |
263 | static_assert(RhsCols == Cols || RhsCols == 1, "" ); |
264 | static_assert(ResultBlockType::kRegisterLanes == 1, |
265 | "This path is only for scalar values" ); |
266 | static_assert(Lhs::kRegisterLanes == 1, |
267 | "This path is only for scalar values" ); |
268 | static_assert(Rhs::kRegisterLanes == 1, |
269 | "This path is only for scalar values" ); |
270 | |
271 | for (int c = 0; c < Cols; c++) { |
272 | const int lhs_c = LhsCols == Cols ? c : 0; |
273 | const int rhs_c = RhsCols == Cols ? c : 0; |
274 | for (int r = 0; r < Rows; r++) { |
275 | const int lhs_r = LhsRows == Rows ? r : 0; |
276 | const int rhs_r = RhsRows == Rows ? r : 0; |
277 | result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul( |
278 | lhs.buf.reg[lhs_r + lhs_c * LhsRows], |
279 | rhs.buf.reg[rhs_r + rhs_c * RhsRows]); |
280 | } |
281 | } |
282 | return result; |
283 | } |
284 | }; |
285 | |
286 | template <typename Lhs, typename Rhs> |
287 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type |
288 | BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) { |
289 | using Flip = FlipLhsRhs<Lhs, Rhs>; |
290 | return BroadcastSaturatingRoundingDoublingHighMulImpl< |
291 | typename Flip::FlippedLhsType, |
292 | typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), |
293 | Flip::FlippedRhs(lhs, rhs)); |
294 | } |
295 | |
296 | template <typename Lhs, typename Rhs> |
297 | struct BroadcastRoundingDivideByPOTImpl { |
298 | using ResultBlockType = |
299 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; |
300 | static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { |
301 | ResultBlockType result; |
302 | static constexpr int Rows = ResultBlockType::kRows; |
303 | static constexpr int Cols = ResultBlockType::kCols; |
304 | static constexpr int LhsRows = Lhs::kRows; |
305 | static constexpr int LhsCols = Lhs::kCols; |
306 | static constexpr int RhsRows = Rhs::kRows; |
307 | static constexpr int RhsCols = Rhs::kCols; |
308 | |
309 | static_assert(LhsRows == Rows || LhsRows == 1, "" ); |
310 | static_assert(RhsRows == Rows || RhsRows == 1, "" ); |
311 | static_assert(LhsCols == Cols || LhsCols == 1, "" ); |
312 | static_assert(RhsCols == Cols || RhsCols == 1, "" ); |
313 | static_assert(ResultBlockType::kRegisterLanes == 1, |
314 | "This path is only for scalar values" ); |
315 | static_assert(Lhs::kRegisterLanes == 1, |
316 | "This path is only for scalar values" ); |
317 | static_assert(Rhs::kRegisterLanes == 1, |
318 | "This path is only for scalar values" ); |
319 | |
320 | for (int c = 0; c < Cols; c++) { |
321 | const int lhs_c = LhsCols == Cols ? c : 0; |
322 | const int rhs_c = RhsCols == Cols ? c : 0; |
323 | for (int r = 0; r < Rows; r++) { |
324 | const int lhs_r = LhsRows == Rows ? r : 0; |
325 | const int rhs_r = RhsRows == Rows ? r : 0; |
326 | result.buf.reg[r + c * Rows] = |
327 | RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows], |
328 | rhs.buf.reg[rhs_r + rhs_c * RhsRows]); |
329 | } |
330 | } |
331 | return result; |
332 | } |
333 | }; |
334 | |
335 | template <typename Lhs, typename Rhs> |
336 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type |
337 | BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) { |
338 | using Flip = FlipLhsRhs<Lhs, Rhs>; |
339 | return BroadcastRoundingDivideByPOTImpl< |
340 | typename Flip::FlippedLhsType, |
341 | typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), |
342 | Flip::FlippedRhs(lhs, rhs)); |
343 | } |
344 | |
345 | template <typename Lhs, typename Rhs> |
346 | struct BroadcastMulImpl { |
347 | using ResultBlockType = |
348 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; |
349 | static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { |
350 | ResultBlockType result; |
351 | static constexpr int Rows = ResultBlockType::kRows; |
352 | static constexpr int Cols = ResultBlockType::kCols; |
353 | static constexpr int LhsRows = Lhs::kRows; |
354 | static constexpr int LhsCols = Lhs::kCols; |
355 | static constexpr int RhsRows = Rhs::kRows; |
356 | static constexpr int RhsCols = Rhs::kCols; |
357 | static_assert(ResultBlockType::kRegisterLanes == 1, |
358 | "This path is only for scalar values" ); |
359 | static_assert(Lhs::kRegisterLanes == 1, |
360 | "This path is only for scalar values" ); |
361 | static_assert(Rhs::kRegisterLanes == 1, |
362 | "This path is only for scalar values" ); |
363 | |
364 | static_assert(LhsRows == Rows || LhsRows == 1, "" ); |
365 | static_assert(RhsRows == Rows || RhsRows == 1, "" ); |
366 | static_assert(LhsCols == Cols || LhsCols == 1, "" ); |
367 | static_assert(RhsCols == Cols || RhsCols == 1, "" ); |
368 | for (int c = 0; c < Cols; c++) { |
369 | const int lhs_c = LhsCols == Cols ? c : 0; |
370 | const int rhs_c = RhsCols == Cols ? c : 0; |
371 | for (int r = 0; r < Rows; r++) { |
372 | const int lhs_r = LhsRows == Rows ? r : 0; |
373 | const int rhs_r = RhsRows == Rows ? r : 0; |
374 | result.buf.reg[r + c * Rows] = |
375 | Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows], |
376 | rhs.buf.reg[rhs_r + rhs_c * RhsRows]); |
377 | } |
378 | } |
379 | return result; |
380 | } |
381 | }; |
382 | |
383 | template <typename Lhs, typename Rhs> |
384 | typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul( |
385 | const Lhs& lhs, const Rhs& rhs) { |
386 | using Flip = FlipLhsRhs<Lhs, Rhs>; |
387 | return BroadcastMulImpl< |
388 | typename Flip::FlippedLhsType, |
389 | typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), |
390 | Flip::FlippedRhs(lhs, rhs)); |
391 | } |
392 | |
// Generic scalar implementation of broadcasting multiply-accumulate: for each
// element of *acc, calls MulAdd() with the matching lhs and rhs elements.
//
// The accumulator's shape defines the iteration space. Each operand dimension
// must either match the accumulator's extent or be 1; a dimension of extent 1
// is broadcast by always reading index 0 along it. Elements of every block
// are addressed as row + col * rows. Only valid when each register holds
// exactly one scalar.
template <typename Lhs, typename Rhs, typename Acc>
struct BroadcastMulAddImpl {
  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
    constexpr int kAccRows = Acc::kRows;
    constexpr int kAccCols = Acc::kCols;

    static_assert(Acc::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRows == kAccRows || Lhs::kRows == 1, "");
    static_assert(Rhs::kRows == kAccRows || Rhs::kRows == 1, "");
    static_assert(Lhs::kCols == kAccCols || Lhs::kCols == 1, "");
    static_assert(Rhs::kCols == kAccCols || Rhs::kCols == 1, "");

    // A step of 0 pins a broadcast dimension to index 0.
    constexpr int kLhsRowStep = (Lhs::kRows == kAccRows) ? 1 : 0;
    constexpr int kRhsRowStep = (Rhs::kRows == kAccRows) ? 1 : 0;
    constexpr int kLhsColStep = (Lhs::kCols == kAccCols) ? Lhs::kRows : 0;
    constexpr int kRhsColStep = (Rhs::kCols == kAccCols) ? Rhs::kRows : 0;

    for (int col = 0; col < kAccCols; ++col) {
      const int lhs_base = col * kLhsColStep;
      const int rhs_base = col * kRhsColStep;
      for (int row = 0; row < kAccRows; ++row) {
        MulAdd(lhs.buf.reg[lhs_base + row * kLhsRowStep],
               rhs.buf.reg[rhs_base + row * kRhsRowStep],
               &acc->buf.reg[col * kAccRows + row]);
      }
    }
  }
};
426 | |
427 | template <typename Lhs, typename Rhs, typename Acc> |
428 | void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) { |
429 | using Flip = FlipLhsRhs<Lhs, Rhs>; |
430 | BroadcastMulAddImpl<typename Flip::FlippedLhsType, |
431 | typename Flip::FlippedRhsType, |
432 | Acc>::Run(Flip::FlippedLhs(lhs, rhs), |
433 | Flip::FlippedRhs(lhs, rhs), acc); |
434 | } |
435 | |
// Primary LoadImpl template. It only exists to be specialized: instantiating
// it for any source type other than void trips the static_assert below.
template <typename BlockType, typename SourceType>
struct LoadImpl {
  static_assert(std::is_same<SourceType, void>::value,
                "This generic impl should never be hit");
};
441 | |
442 | template <typename ScalarType, int Rows, int Cols, typename SrcScalarType> |
443 | struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, |
444 | MatrixMap<SrcScalarType, MapOrder::ColMajor>> { |
445 | using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; |
446 | using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>; |
447 | static RegisterBlockType Run(const SrcObjectType& src, int row, int col) { |
448 | RegisterBlockType result; |
449 | int i = 0; |
450 | for (int c = 0; c < Cols; c++) { |
451 | const ScalarType* src_ptr = src.data(row, col + c); |
452 | for (int r = 0; r < Rows; r++) { |
453 | result.buf.reg[i++] = *src_ptr++; |
454 | } |
455 | } |
456 | return result; |
457 | } |
458 | }; |
459 | |
460 | template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, |
461 | VectorShape Shape> |
462 | struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, |
463 | VectorMap<SrcScalarType, Shape>> { |
464 | using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; |
465 | using SrcObjectType = VectorMap<SrcScalarType, Shape>; |
466 | static RegisterBlockType Run(const SrcObjectType& src, int pos) { |
467 | static_assert(Shape == VectorShape::Col || Rows == 1, "" ); |
468 | static_assert(Shape == VectorShape::Row || Cols == 1, "" ); |
469 | RegisterBlockType result; |
470 | for (int i = 0; i < Rows * Cols; i++) { |
471 | result.buf.reg[i] = src(pos + i); |
472 | } |
473 | return result; |
474 | } |
475 | }; |
476 | |
477 | template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, |
478 | VectorShape Shape> |
479 | struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, |
480 | VectorDup<SrcScalarType, Shape>> { |
481 | using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; |
482 | using SrcObjectType = VectorDup<SrcScalarType, Shape>; |
483 | static RegisterBlockType Run(const SrcObjectType& src, int) { |
484 | static_assert(Shape == VectorShape::Col || Rows == 1, "" ); |
485 | static_assert(Shape == VectorShape::Row || Cols == 1, "" ); |
486 | RegisterBlockType result; |
487 | for (int i = 0; i < Rows * Cols; i++) { |
488 | result.buf.reg[i] = src(0); |
489 | } |
490 | return result; |
491 | } |
492 | }; |
493 | |
494 | template <typename RegisterBlockType, typename SrcObjectType> |
495 | RegisterBlockType Load(const SrcObjectType& src, int row, int col) { |
496 | return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col); |
497 | } |
498 | |
499 | template <typename RegisterBlockType, typename SrcObjectType> |
500 | RegisterBlockType Load(const SrcObjectType& src, int pos) { |
501 | return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos); |
502 | } |
503 | |
// Generic scalar implementation of a contiguous load: copies kScalarCount
// consecutive scalars starting at src into the block's registers, in order.
// Only valid when each register holds exactly one scalar.
template <typename RegisterBlockType>
struct LoadContiguousImpl {
  using ScalarType = typename RegisterBlockType::ScalarType;
  static_assert(RegisterBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");

  static RegisterBlockType Run(const ScalarType* src) {
    constexpr int kCount = RegisterBlockType::kScalarCount;
    RegisterBlockType block;
    for (int k = 0; k != kCount; ++k) {
      block.buf.reg[k] = *(src + k);
    }
    return block;
  }
};
517 | |
518 | template <typename RegisterBlockType> |
519 | RegisterBlockType LoadContiguous( |
520 | const typename RegisterBlockType::ScalarType* src) { |
521 | return LoadContiguousImpl<RegisterBlockType>::Run(src); |
522 | } |
523 | |
// Primary template: intentionally empty. The shape of a block loaded for
// broadcasting is only defined by the specializations for specific source
// types.
template <int BroadcastRows, int BroadcastCols, typename SourceType>
struct LoadForBroadcastingShape {
};
526 | |
527 | template <int BroadcastRows, int BroadcastCols, typename ScalarType, |
528 | VectorShape Shape> |
529 | struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, |
530 | VectorMap<ScalarType, Shape>> { |
531 | static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1; |
532 | static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1; |
533 | }; |
534 | |
535 | template <int BroadcastRows, int BroadcastCols, typename ScalarType, |
536 | VectorShape Shape> |
537 | struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, |
538 | VectorDup<ScalarType, Shape>> { |
539 | static constexpr int kRows = 1; |
540 | static constexpr int kCols = 1; |
541 | }; |
542 | |
543 | template <typename RegisterBlockType, typename SrcObjectType> |
544 | struct LoadForBroadcastingRegisterBlock { |
545 | using Shape = |
546 | LoadForBroadcastingShape<RegisterBlockType::kRows, |
547 | RegisterBlockType::kCols, SrcObjectType>; |
548 | using ScalarType = typename RegisterBlockType::ScalarType; |
549 | using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; |
550 | }; |
551 | |
// Primary LoadForBroadcastingImpl template. It only exists to be
// specialized: instantiating it for any source type other than void trips
// the static_assert below.
template <typename BlockType, typename SourceType>
struct LoadForBroadcastingImpl {
  static_assert(std::is_same<SourceType, void>::value,
                "This generic impl should never be hit");
};
557 | |
558 | template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, |
559 | VectorShape Shape> |
560 | struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, |
561 | VectorMap<SrcScalarType, Shape>> { |
562 | using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; |
563 | using SrcObjectType = VectorMap<SrcScalarType, Shape>; |
564 | using ResultBlockType = |
565 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
566 | SrcObjectType>::Type; |
567 | static_assert(ResultBlockType::kRegisterLanes == 1, |
568 | "This path is only for scalar values" ); |
569 | static ResultBlockType Run(const SrcObjectType& src, int pos) { |
570 | ResultBlockType result; |
571 | for (int c = 0; c < ResultBlockType::kCols; c++) { |
572 | for (int r = 0; r < ResultBlockType::kRows; r++) { |
573 | const int i = Shape == VectorShape::Col ? r : c; |
574 | result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i); |
575 | } |
576 | } |
577 | return result; |
578 | } |
579 | }; |
580 | |
581 | template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, |
582 | VectorShape Shape> |
583 | struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, |
584 | VectorDup<SrcScalarType, Shape>> { |
585 | using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; |
586 | using SrcObjectType = VectorDup<SrcScalarType, Shape>; |
587 | using ResultBlockType = |
588 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
589 | SrcObjectType>::Type; |
590 | static_assert(ResultBlockType::kRegisterLanes == 1, |
591 | "This path is only for scalar values" ); |
592 | static ResultBlockType Run(const SrcObjectType& src, int) { |
593 | ResultBlockType result; |
594 | for (int c = 0; c < ResultBlockType::kCols; c++) { |
595 | for (int r = 0; r < ResultBlockType::kRows; r++) { |
596 | result.buf.reg[r + c * ResultBlockType::kRows] = src(0); |
597 | } |
598 | } |
599 | return result; |
600 | } |
601 | }; |
602 | |
603 | template <typename RegisterBlockType, typename SrcObjectType> |
604 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
605 | SrcObjectType>::Type |
606 | LoadForBroadcasting(const SrcObjectType& src, int row, int col) { |
607 | return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run( |
608 | src, row, col); |
609 | } |
610 | |
611 | template <typename RegisterBlockType, typename SrcObjectType> |
612 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
613 | SrcObjectType>::Type |
614 | LoadForBroadcasting(const SrcObjectType& src, int pos) { |
615 | return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src, |
616 | pos); |
617 | } |
618 | |
619 | template <int ConstantValue, typename RegisterBlockType> |
620 | struct AddConstantImpl { |
621 | static void Run(RegisterBlockType* block) { |
622 | using RegisterType = typename RegisterBlockType::RegisterType; |
623 | const RegisterType dup = Dup<RegisterType>(ConstantValue); |
624 | for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { |
625 | block->buf.reg[i] = Add(block->buf.reg[i], dup); |
626 | } |
627 | } |
628 | }; |
629 | |
630 | template <typename RegisterBlockType> |
631 | struct AddConstantImpl<0, RegisterBlockType> { |
632 | static void Run(RegisterBlockType*) { |
633 | // This is a no-op. |
634 | } |
635 | }; |
636 | |
637 | template <int ConstantValue, typename RegisterBlockType> |
638 | void AddConstant(RegisterBlockType* block) { |
639 | AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block); |
640 | } |
641 | |
642 | template <int N> |
643 | using RegBufferInt32 = RegisterBuffer<std::int32_t, N>; |
644 | template <int N> |
645 | using RegBufferInt16 = RegisterBuffer<std::int16_t, N>; |
646 | template <int N> |
647 | using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>; |
648 | template <int N> |
649 | using RegBufferInt8 = RegisterBuffer<std::int8_t, N>; |
650 | template <int R, int C> |
651 | using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>; |
652 | template <int R, int C> |
653 | using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>; |
654 | template <int R, int C> |
655 | using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>; |
656 | template <int R, int C> |
657 | using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>; |
658 | |
659 | } // end namespace gemmlowp |
660 | |
661 | #if defined GEMMLOWP_NEON |
662 | #include "simd_wrappers_neon.h" |
663 | #elif defined GEMMLOWP_SSE4 |
664 | #include "simd_wrappers_sse.h" |
665 | #elif defined GEMMLOWP_MSA |
666 | #include "simd_wrappers_msa.h" |
667 | #endif |
668 | |
669 | #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ |
670 | |