1// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
16// extending the set of such functions from fixedpoint.h.
17
18#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
19#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
20
21#include <algorithm>
22#include <type_traits>
23#include "../fixedpoint/fixedpoint.h"
24
25namespace gemmlowp {
26
// Maps a (scalar type, scalar count) pair to the type of the register that
// holds such scalars. This primary template ignores the count and uses the
// scalar type itself as its own "register".
// NOTE(review): presumably specialized for real SIMD register types by the
// architecture-specific headers included at the end of this file - confirm.
template <typename ScalarType, int /*ScalarCount*/>
struct RegisterType {
  using Type = ScalarType;
};
31
// Scalar minimum: returns the smaller of the two 32-bit signed values
// (returns `a` when both are equal).
inline std::int32_t Min(std::int32_t a, std::int32_t b) {
  return (b < a) ? b : a;
}
35
// Scalar maximum: returns the larger of the two 32-bit signed values
// (returns `a` when both are equal).
inline std::int32_t Max(std::int32_t a, std::int32_t b) {
  return (a < b) ? b : a;
}
39
// Scalar multiply-accumulate: *acc += lhs * rhs.
//
// The arithmetic is carried out on the unsigned 32-bit representation so that
// results which do not fit in 32 bits wrap around modulo 2^32 instead of
// triggering signed-overflow undefined behavior. Whenever the exact result is
// representable, the outcome is identical to plain signed arithmetic.
//
// lhs, rhs: the two factors.
// acc:      in/out accumulator; must point to a valid std::int32_t.
inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
  const std::uint32_t product =
      static_cast<std::uint32_t>(lhs) * static_cast<std::uint32_t>(rhs);
  const std::uint32_t sum = static_cast<std::uint32_t>(*acc) + product;
  *acc = static_cast<std::int32_t>(sum);
}
43
44template <typename tScalarType, int tScalarCount>
45struct RegisterBuffer {
46 using ScalarType = tScalarType;
47 static constexpr int kScalarCount = tScalarCount;
48 using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
49 static_assert((kScalarCount & (kScalarCount - 1)) == 0,
50 "kScalarCount must be a power of two");
51 static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
52 static constexpr int kRegisterLanes =
53 sizeof(RegisterType) / sizeof(ScalarType);
54 static constexpr int kRegisterCount =
55 (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
56 sizeof(RegisterType);
57
58 RegisterType reg[kRegisterCount];
59};
60
61template <typename tScalarType, int tRows, int tCols>
62struct RegisterBlock {
63 using ScalarType = tScalarType;
64 static constexpr int kRows = tRows;
65 static constexpr int kCols = tCols;
66 static constexpr int kScalarCount = kRows * kCols;
67 using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
68 using RegisterType = typename BufferType::RegisterType;
69 static constexpr int kRegisterCount = BufferType::kRegisterCount;
70 static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
71
72 BufferType buf;
73};
74
// Adds two register blocks of the same type, register by register.
// Each pair of registers is combined with Add() (not defined in this file;
// presumably provided by fixedpoint.h - confirm).
template <typename RegisterBlockType>
struct RegisterBlockAddImpl {
  static RegisterBlockType Run(const RegisterBlockType& lhs,
                               const RegisterBlockType& rhs) {
    constexpr int kCount = RegisterBlockType::kRegisterCount;
    RegisterBlockType sum;
    int idx = 0;
    while (idx < kCount) {
      sum.buf.reg[idx] = Add(lhs.buf.reg[idx], rhs.buf.reg[idx]);
      ++idx;
    }
    return sum;
  }
};
86
87template <typename RegisterBlockType>
88RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
89 const RegisterBlockType& rhs) {
90 return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
91}
92
// Compile-time predicate deciding whether the operands of a binary broadcast
// operation should be swapped.
//
// kValue is true when the left operand holds fewer scalars than the right
// one, or when both hold the same number of scalars and the left operand has
// fewer rows.
template <typename LhsType, typename RhsType>
struct ShouldFlipLhsRhs {
  static constexpr bool kValue =
      (LhsType::kScalarCount != RhsType::kScalarCount)
          ? (LhsType::kScalarCount < RhsType::kScalarCount)
          : (LhsType::kRows < RhsType::kRows);
};
100
101template <typename LhsType, typename RhsType,
102 bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
103struct FlipLhsRhs {
104 using FlippedLhsType = LhsType;
105 using FlippedRhsType = RhsType;
106 static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
107 const RhsType& rhs) {
108 (void)rhs;
109 return lhs;
110 }
111 static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
112 const RhsType& rhs) {
113 (void)lhs;
114 return rhs;
115 }
116};
117
118template <typename LhsType, typename RhsType>
119struct FlipLhsRhs<LhsType, RhsType, true> {
120 using FlippedLhsType = RhsType;
121 using FlippedRhsType = LhsType;
122 static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
123 const RhsType& rhs) {
124 (void)lhs;
125 return rhs;
126 }
127 static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
128 const RhsType& rhs) {
129 (void)rhs;
130 return lhs;
131 }
132};
133
// Shape of the result of a binary broadcast operation: along each dimension
// the result takes the larger of the two operand extents.
template <typename Lhs, typename Rhs>
struct BroadcastBinaryOpShape {
  static constexpr int kRows =
      (Lhs::kRows < Rhs::kRows) ? Rhs::kRows : Lhs::kRows;
  static constexpr int kCols =
      (Lhs::kCols < Rhs::kCols) ? Rhs::kCols : Lhs::kCols;
};
141
142template <typename Lhs, typename Rhs>
143struct BroadcastBinaryOpRegisterBlock {
144 using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
145 using ScalarType = typename Lhs::ScalarType;
146 using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
147};
148
149template <typename Lhs, typename Rhs>
150struct BroadcastAddImpl {
151 using ResultBlockType =
152 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
153 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
154 ResultBlockType result;
155 static constexpr int Rows = ResultBlockType::kRows;
156 static constexpr int Cols = ResultBlockType::kCols;
157 static constexpr int LhsRows = Lhs::kRows;
158 static constexpr int LhsCols = Lhs::kCols;
159 static constexpr int RhsRows = Rhs::kRows;
160 static constexpr int RhsCols = Rhs::kCols;
161
162 static_assert(LhsRows == Rows || LhsRows == 1, "");
163 static_assert(RhsRows == Rows || RhsRows == 1, "");
164 static_assert(LhsCols == Cols || LhsCols == 1, "");
165 static_assert(RhsCols == Cols || RhsCols == 1, "");
166 static_assert(ResultBlockType::kRegisterLanes == 1,
167 "This path is only for scalar values");
168 static_assert(Lhs::kRegisterLanes == 1,
169 "This path is only for scalar values");
170 static_assert(Rhs::kRegisterLanes == 1,
171 "This path is only for scalar values");
172
173 for (int c = 0; c < Cols; c++) {
174 const int lhs_c = LhsCols == Cols ? c : 0;
175 const int rhs_c = RhsCols == Cols ? c : 0;
176 for (int r = 0; r < Rows; r++) {
177 const int lhs_r = LhsRows == Rows ? r : 0;
178 const int rhs_r = RhsRows == Rows ? r : 0;
179 result.buf.reg[r + c * Rows] =
180 Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
181 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
182 }
183 }
184 return result;
185 }
186};
187
188template <typename Lhs, typename Rhs>
189typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
190 const Lhs& lhs, const Rhs& rhs) {
191 using Flip = FlipLhsRhs<Lhs, Rhs>;
192 return BroadcastAddImpl<
193 typename Flip::FlippedLhsType,
194 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
195 Flip::FlippedRhs(lhs, rhs));
196}
197
198template <typename Lhs, typename Rhs>
199struct BroadcastShiftLeftImpl {
200 using ResultBlockType =
201 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
202 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
203 ResultBlockType result;
204 static constexpr int Rows = ResultBlockType::kRows;
205 static constexpr int Cols = ResultBlockType::kCols;
206 static constexpr int LhsRows = Lhs::kRows;
207 static constexpr int LhsCols = Lhs::kCols;
208 static constexpr int RhsRows = Rhs::kRows;
209 static constexpr int RhsCols = Rhs::kCols;
210
211 static_assert(LhsRows == Rows || LhsRows == 1, "");
212 static_assert(RhsRows == Rows || RhsRows == 1, "");
213 static_assert(LhsCols == Cols || LhsCols == 1, "");
214 static_assert(RhsCols == Cols || RhsCols == 1, "");
215 static_assert(ResultBlockType::kRegisterLanes == 1,
216 "This path is only for scalar values");
217 static_assert(Lhs::kRegisterLanes == 1,
218 "This path is only for scalar values");
219 static_assert(Rhs::kRegisterLanes == 1,
220 "This path is only for scalar values");
221
222 for (int c = 0; c < Cols; c++) {
223 const int lhs_c = LhsCols == Cols ? c : 0;
224 const int rhs_c = RhsCols == Cols ? c : 0;
225 for (int r = 0; r < Rows; r++) {
226 const int lhs_r = LhsRows == Rows ? r : 0;
227 const int rhs_r = RhsRows == Rows ? r : 0;
228 result.buf.reg[r + c * Rows] =
229 ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
230 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
231 }
232 }
233 return result;
234 }
235};
236
237template <typename Lhs, typename Rhs>
238typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft(
239 const Lhs& lhs, const Rhs& rhs) {
240 using Flip = FlipLhsRhs<Lhs, Rhs>;
241 return BroadcastShiftLeftImpl<
242 typename Flip::FlippedLhsType,
243 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
244 Flip::FlippedRhs(lhs, rhs));
245}
246
247template <typename Lhs, typename Rhs>
248struct BroadcastSaturatingRoundingDoublingHighMulImpl {
249 using ResultBlockType =
250 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
251 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
252 ResultBlockType result;
253 static constexpr int Rows = ResultBlockType::kRows;
254 static constexpr int Cols = ResultBlockType::kCols;
255 static constexpr int LhsRows = Lhs::kRows;
256 static constexpr int LhsCols = Lhs::kCols;
257 static constexpr int RhsRows = Rhs::kRows;
258 static constexpr int RhsCols = Rhs::kCols;
259
260 static_assert(LhsRows == Rows || LhsRows == 1, "");
261 static_assert(RhsRows == Rows || RhsRows == 1, "");
262 static_assert(LhsCols == Cols || LhsCols == 1, "");
263 static_assert(RhsCols == Cols || RhsCols == 1, "");
264 static_assert(ResultBlockType::kRegisterLanes == 1,
265 "This path is only for scalar values");
266 static_assert(Lhs::kRegisterLanes == 1,
267 "This path is only for scalar values");
268 static_assert(Rhs::kRegisterLanes == 1,
269 "This path is only for scalar values");
270
271 for (int c = 0; c < Cols; c++) {
272 const int lhs_c = LhsCols == Cols ? c : 0;
273 const int rhs_c = RhsCols == Cols ? c : 0;
274 for (int r = 0; r < Rows; r++) {
275 const int lhs_r = LhsRows == Rows ? r : 0;
276 const int rhs_r = RhsRows == Rows ? r : 0;
277 result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul(
278 lhs.buf.reg[lhs_r + lhs_c * LhsRows],
279 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
280 }
281 }
282 return result;
283 }
284};
285
286template <typename Lhs, typename Rhs>
287typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
288BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) {
289 using Flip = FlipLhsRhs<Lhs, Rhs>;
290 return BroadcastSaturatingRoundingDoublingHighMulImpl<
291 typename Flip::FlippedLhsType,
292 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
293 Flip::FlippedRhs(lhs, rhs));
294}
295
296template <typename Lhs, typename Rhs>
297struct BroadcastRoundingDivideByPOTImpl {
298 using ResultBlockType =
299 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
300 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
301 ResultBlockType result;
302 static constexpr int Rows = ResultBlockType::kRows;
303 static constexpr int Cols = ResultBlockType::kCols;
304 static constexpr int LhsRows = Lhs::kRows;
305 static constexpr int LhsCols = Lhs::kCols;
306 static constexpr int RhsRows = Rhs::kRows;
307 static constexpr int RhsCols = Rhs::kCols;
308
309 static_assert(LhsRows == Rows || LhsRows == 1, "");
310 static_assert(RhsRows == Rows || RhsRows == 1, "");
311 static_assert(LhsCols == Cols || LhsCols == 1, "");
312 static_assert(RhsCols == Cols || RhsCols == 1, "");
313 static_assert(ResultBlockType::kRegisterLanes == 1,
314 "This path is only for scalar values");
315 static_assert(Lhs::kRegisterLanes == 1,
316 "This path is only for scalar values");
317 static_assert(Rhs::kRegisterLanes == 1,
318 "This path is only for scalar values");
319
320 for (int c = 0; c < Cols; c++) {
321 const int lhs_c = LhsCols == Cols ? c : 0;
322 const int rhs_c = RhsCols == Cols ? c : 0;
323 for (int r = 0; r < Rows; r++) {
324 const int lhs_r = LhsRows == Rows ? r : 0;
325 const int rhs_r = RhsRows == Rows ? r : 0;
326 result.buf.reg[r + c * Rows] =
327 RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
328 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
329 }
330 }
331 return result;
332 }
333};
334
335template <typename Lhs, typename Rhs>
336typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
337BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) {
338 using Flip = FlipLhsRhs<Lhs, Rhs>;
339 return BroadcastRoundingDivideByPOTImpl<
340 typename Flip::FlippedLhsType,
341 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
342 Flip::FlippedRhs(lhs, rhs));
343}
344
345template <typename Lhs, typename Rhs>
346struct BroadcastMulImpl {
347 using ResultBlockType =
348 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
349 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
350 ResultBlockType result;
351 static constexpr int Rows = ResultBlockType::kRows;
352 static constexpr int Cols = ResultBlockType::kCols;
353 static constexpr int LhsRows = Lhs::kRows;
354 static constexpr int LhsCols = Lhs::kCols;
355 static constexpr int RhsRows = Rhs::kRows;
356 static constexpr int RhsCols = Rhs::kCols;
357 static_assert(ResultBlockType::kRegisterLanes == 1,
358 "This path is only for scalar values");
359 static_assert(Lhs::kRegisterLanes == 1,
360 "This path is only for scalar values");
361 static_assert(Rhs::kRegisterLanes == 1,
362 "This path is only for scalar values");
363
364 static_assert(LhsRows == Rows || LhsRows == 1, "");
365 static_assert(RhsRows == Rows || RhsRows == 1, "");
366 static_assert(LhsCols == Cols || LhsCols == 1, "");
367 static_assert(RhsCols == Cols || RhsCols == 1, "");
368 for (int c = 0; c < Cols; c++) {
369 const int lhs_c = LhsCols == Cols ? c : 0;
370 const int rhs_c = RhsCols == Cols ? c : 0;
371 for (int r = 0; r < Rows; r++) {
372 const int lhs_r = LhsRows == Rows ? r : 0;
373 const int rhs_r = RhsRows == Rows ? r : 0;
374 result.buf.reg[r + c * Rows] =
375 Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
376 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
377 }
378 }
379 return result;
380 }
381};
382
383template <typename Lhs, typename Rhs>
384typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
385 const Lhs& lhs, const Rhs& rhs) {
386 using Flip = FlipLhsRhs<Lhs, Rhs>;
387 return BroadcastMulImpl<
388 typename Flip::FlippedLhsType,
389 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
390 Flip::FlippedRhs(lhs, rhs));
391}
392
// Generic scalar implementation of multiply-accumulate with broadcasting:
// for every element of the accumulator block, MulAdd() is called with the
// matching (possibly broadcast) lhs and rhs elements and a pointer to that
// accumulator element.
//
// The accumulator defines the result shape. Along each dimension an operand
// must either span the full accumulator extent or have extent 1, in which
// case its single row/column is reused. Blocks are indexed as r + c * rows.
template <typename Lhs, typename Rhs, typename Acc>
struct BroadcastMulAddImpl {
  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
    constexpr int kAccRows = Acc::kRows;
    constexpr int kAccCols = Acc::kCols;
    constexpr int kLhsRows = Lhs::kRows;
    constexpr int kLhsCols = Lhs::kCols;
    constexpr int kRhsRows = Rhs::kRows;
    constexpr int kRhsCols = Rhs::kCols;
    // Whether each operand advances together with the accumulator row/column.
    constexpr bool kLhsWalksRows = (kLhsRows == kAccRows);
    constexpr bool kRhsWalksRows = (kRhsRows == kAccRows);
    constexpr bool kLhsWalksCols = (kLhsCols == kAccCols);
    constexpr bool kRhsWalksCols = (kRhsCols == kAccCols);

    static_assert(Acc::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(kLhsWalksRows || kLhsRows == 1, "");
    static_assert(kRhsWalksRows || kRhsRows == 1, "");
    static_assert(kLhsWalksCols || kLhsCols == 1, "");
    static_assert(kRhsWalksCols || kRhsCols == 1, "");

    // `out` equals row + col * kAccRows throughout the traversal.
    int out = 0;
    for (int col = 0; col < kAccCols; ++col) {
      const int lhs_base = (kLhsWalksCols ? col : 0) * kLhsRows;
      const int rhs_base = (kRhsWalksCols ? col : 0) * kRhsRows;
      for (int row = 0; row < kAccRows; ++row, ++out) {
        const int lhs_index = lhs_base + (kLhsWalksRows ? row : 0);
        const int rhs_index = rhs_base + (kRhsWalksRows ? row : 0);
        MulAdd(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index],
               &acc->buf.reg[out]);
      }
    }
  }
};
426
427template <typename Lhs, typename Rhs, typename Acc>
428void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
429 using Flip = FlipLhsRhs<Lhs, Rhs>;
430 BroadcastMulAddImpl<typename Flip::FlippedLhsType,
431 typename Flip::FlippedRhsType,
432 Acc>::Run(Flip::FlippedLhs(lhs, rhs),
433 Flip::FlippedRhs(lhs, rhs), acc);
434}
435
// Primary template for loading a register block from a source object.
// It has no Run() and is only a compile-time trap: instantiating it with any
// SrcObjectType other than void fails the static_assert below, so every
// supported source type needs its own specialization.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadImpl {
  static_assert(std::is_same<SrcObjectType, void>::value,
                "This generic impl should never be hit");
};
441
442template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
443struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
444 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
445 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
446 using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
447 static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
448 RegisterBlockType result;
449 int i = 0;
450 for (int c = 0; c < Cols; c++) {
451 const ScalarType* src_ptr = src.data(row, col + c);
452 for (int r = 0; r < Rows; r++) {
453 result.buf.reg[i++] = *src_ptr++;
454 }
455 }
456 return result;
457 }
458};
459
460template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
461 VectorShape Shape>
462struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
463 VectorMap<SrcScalarType, Shape>> {
464 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
465 using SrcObjectType = VectorMap<SrcScalarType, Shape>;
466 static RegisterBlockType Run(const SrcObjectType& src, int pos) {
467 static_assert(Shape == VectorShape::Col || Rows == 1, "");
468 static_assert(Shape == VectorShape::Row || Cols == 1, "");
469 RegisterBlockType result;
470 for (int i = 0; i < Rows * Cols; i++) {
471 result.buf.reg[i] = src(pos + i);
472 }
473 return result;
474 }
475};
476
477template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
478 VectorShape Shape>
479struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
480 VectorDup<SrcScalarType, Shape>> {
481 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
482 using SrcObjectType = VectorDup<SrcScalarType, Shape>;
483 static RegisterBlockType Run(const SrcObjectType& src, int) {
484 static_assert(Shape == VectorShape::Col || Rows == 1, "");
485 static_assert(Shape == VectorShape::Row || Cols == 1, "");
486 RegisterBlockType result;
487 for (int i = 0; i < Rows * Cols; i++) {
488 result.buf.reg[i] = src(0);
489 }
490 return result;
491 }
492};
493
494template <typename RegisterBlockType, typename SrcObjectType>
495RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
496 return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
497}
498
499template <typename RegisterBlockType, typename SrcObjectType>
500RegisterBlockType Load(const SrcObjectType& src, int pos) {
501 return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
502}
503
// Generic scalar implementation of loading a register block from a
// contiguous array: entry i of the block is set to src[i], for the block's
// kScalarCount entries. Only valid when each register holds one scalar.
template <typename RegisterBlockType>
struct LoadContiguousImpl {
  using ScalarType = typename RegisterBlockType::ScalarType;
  static_assert(RegisterBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");

  // src: must point to at least kScalarCount readable scalars.
  static RegisterBlockType Run(const ScalarType* src) {
    constexpr int kCount = RegisterBlockType::kScalarCount;
    RegisterBlockType result;
    int slot = 0;
    while (slot < kCount) {
      result.buf.reg[slot] = src[slot];
      ++slot;
    }
    return result;
  }
};
517
518template <typename RegisterBlockType>
519RegisterBlockType LoadContiguous(
520 const typename RegisterBlockType::ScalarType* src) {
521 return LoadContiguousImpl<RegisterBlockType>::Run(src);
522}
523
// Shape (kRows, kCols) of the block obtained when loading from SrcObjectType
// for broadcasting against a BroadcastRows x BroadcastCols block.
// This primary template is intentionally empty; the shape constants are
// supplied by the specializations for supported source types.
template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
struct LoadForBroadcastingShape {
};
526
527template <int BroadcastRows, int BroadcastCols, typename ScalarType,
528 VectorShape Shape>
529struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
530 VectorMap<ScalarType, Shape>> {
531 static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1;
532 static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1;
533};
534
535template <int BroadcastRows, int BroadcastCols, typename ScalarType,
536 VectorShape Shape>
537struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
538 VectorDup<ScalarType, Shape>> {
539 static constexpr int kRows = 1;
540 static constexpr int kCols = 1;
541};
542
543template <typename RegisterBlockType, typename SrcObjectType>
544struct LoadForBroadcastingRegisterBlock {
545 using Shape =
546 LoadForBroadcastingShape<RegisterBlockType::kRows,
547 RegisterBlockType::kCols, SrcObjectType>;
548 using ScalarType = typename RegisterBlockType::ScalarType;
549 using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
550};
551
// Primary template for loading a block to be broadcast against
// RegisterBlockType. It has no Run() and is only a compile-time trap:
// instantiating it with any SrcObjectType other than void fails the
// static_assert below, so every supported source type needs its own
// specialization.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadForBroadcastingImpl {
  static_assert(std::is_same<SrcObjectType, void>::value,
                "This generic impl should never be hit");
};
557
558template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
559 VectorShape Shape>
560struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
561 VectorMap<SrcScalarType, Shape>> {
562 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
563 using SrcObjectType = VectorMap<SrcScalarType, Shape>;
564 using ResultBlockType =
565 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
566 SrcObjectType>::Type;
567 static_assert(ResultBlockType::kRegisterLanes == 1,
568 "This path is only for scalar values");
569 static ResultBlockType Run(const SrcObjectType& src, int pos) {
570 ResultBlockType result;
571 for (int c = 0; c < ResultBlockType::kCols; c++) {
572 for (int r = 0; r < ResultBlockType::kRows; r++) {
573 const int i = Shape == VectorShape::Col ? r : c;
574 result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
575 }
576 }
577 return result;
578 }
579};
580
581template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
582 VectorShape Shape>
583struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
584 VectorDup<SrcScalarType, Shape>> {
585 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
586 using SrcObjectType = VectorDup<SrcScalarType, Shape>;
587 using ResultBlockType =
588 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
589 SrcObjectType>::Type;
590 static_assert(ResultBlockType::kRegisterLanes == 1,
591 "This path is only for scalar values");
592 static ResultBlockType Run(const SrcObjectType& src, int) {
593 ResultBlockType result;
594 for (int c = 0; c < ResultBlockType::kCols; c++) {
595 for (int r = 0; r < ResultBlockType::kRows; r++) {
596 result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
597 }
598 }
599 return result;
600 }
601};
602
603template <typename RegisterBlockType, typename SrcObjectType>
604typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
605 SrcObjectType>::Type
606LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
607 return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
608 src, row, col);
609}
610
611template <typename RegisterBlockType, typename SrcObjectType>
612typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
613 SrcObjectType>::Type
614LoadForBroadcasting(const SrcObjectType& src, int pos) {
615 return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
616 pos);
617}
618
619template <int ConstantValue, typename RegisterBlockType>
620struct AddConstantImpl {
621 static void Run(RegisterBlockType* block) {
622 using RegisterType = typename RegisterBlockType::RegisterType;
623 const RegisterType dup = Dup<RegisterType>(ConstantValue);
624 for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
625 block->buf.reg[i] = Add(block->buf.reg[i], dup);
626 }
627 }
628};
629
630template <typename RegisterBlockType>
631struct AddConstantImpl<0, RegisterBlockType> {
632 static void Run(RegisterBlockType*) {
633 // This is a no-op.
634 }
635};
636
637template <int ConstantValue, typename RegisterBlockType>
638void AddConstant(RegisterBlockType* block) {
639 AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
640}
641
642template <int N>
643using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
644template <int N>
645using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
646template <int N>
647using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
648template <int N>
649using RegBufferInt8 = RegisterBuffer<std::int8_t, N>;
650template <int R, int C>
651using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
652template <int R, int C>
653using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
654template <int R, int C>
655using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
656template <int R, int C>
657using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>;
658
659} // end namespace gemmlowp
660
661#if defined GEMMLOWP_NEON
662#include "simd_wrappers_neon.h"
663#elif defined GEMMLOWP_SSE4
664#include "simd_wrappers_sse.h"
665#elif defined GEMMLOWP_MSA
666#include "simd_wrappers_msa.h"
667#endif
668
669#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
670