1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code
16
17#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
18#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
19
20#include "simd_wrappers.h"
21
22namespace gemmlowp {
23
24template <typename SrcScalarType, int N>
25struct LoadImpl<RegBlockInt32<4, N>,
26 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
27 static RegBlockInt32<4, N> Run(
28 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
29 int col) {
30 RegBlockInt32<4, N> result;
31 for (int i = 0; i < N; i++) {
32 result.buf.reg[i] = LoadInt32x4(src.data(row, col + i));
33 }
34 return result;
35 }
36};
37
38template <typename SrcScalarType, int N>
39struct LoadImpl<RegBlockInt32<8, N>,
40 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
41 static RegBlockInt32<8, N> Run(
42 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
43 int col) {
44 RegBlockInt32<8, N> result;
45 for (int i = 0; i < N; i++) {
46 result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i));
47 result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i));
48 }
49 return result;
50 }
51};
52
53template <typename SrcScalarType>
54struct LoadImpl<RegBlockInt32<1, 4>,
55 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
56 static RegBlockInt32<1, 4> Run(
57 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
58 int col) {
59 RegBlockInt32<1, 4> result;
60 std::int32_t buf[4];
61 for (int i = 0; i < 4; i++) {
62 buf[i] = src(row, col + i);
63 }
64 result.buf.reg[0] = LoadInt32x4(buf);
65 return result;
66 }
67};
68
69template <typename SrcScalarType>
70struct LoadImpl<RegBlockInt32<1, 8>,
71 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
72 static RegBlockInt32<1, 8> Run(
73 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
74 int col) {
75 RegBlockInt32<1, 8> result;
76 std::int32_t buf[8];
77 for (int i = 0; i < 8; i++) {
78 buf[i] = src(row, col + i);
79 }
80 result.buf.reg[0] = LoadInt32x4(buf);
81 result.buf.reg[1] = LoadInt32x4(buf + 4);
82 return result;
83 }
84};
85
86template <typename SrcScalarType>
87struct LoadImpl<RegBlockInt32<4, 1>,
88 VectorMap<SrcScalarType, VectorShape::Col>> {
89 static RegBlockInt32<4, 1> Run(
90 const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) {
91 RegBlockInt32<4, 1> result;
92 result.buf.reg[0] = LoadInt32x4(src.data(pos));
93 return result;
94 }
95};
96
97template <typename SrcScalarType>
98struct LoadImpl<RegBlockInt32<4, 1>,
99 VectorDup<SrcScalarType, VectorShape::Col>> {
100 static RegBlockInt32<4, 1> Run(
101 const VectorDup<SrcScalarType, VectorShape::Col>& src, int) {
102 RegBlockInt32<4, 1> result;
103 result.buf.reg[0] = LoadInt32x4(src(0));
104 return result;
105 }
106};
107
108template <typename SrcScalarType, int N>
109struct LoadForBroadcastingImpl<RegBlockInt32<4, N>,
110 VectorMap<SrcScalarType, VectorShape::Col>> {
111 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
112 using RegisterBlockType = RegBlockInt32<4, N>;
113 using ResultBlockType =
114 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
115 SrcObjectType>::Type;
116
117 static ResultBlockType Run(const SrcObjectType& src, int pos) {
118 ResultBlockType result;
119 static_assert(ResultBlockType::kRegisterCount == 1, "");
120 result.buf.reg[0] = LoadInt32x4(src.data(pos));
121 return result;
122 }
123};
124
125template <typename SrcScalarType, int N>
126struct LoadForBroadcastingImpl<RegBlockInt32<8, N>,
127 VectorMap<SrcScalarType, VectorShape::Col>> {
128 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
129 using RegisterBlockType = RegBlockInt32<8, N>;
130 using ResultBlockType =
131 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
132 SrcObjectType>::Type;
133
134 static ResultBlockType Run(const SrcObjectType& src, int pos) {
135 ResultBlockType result;
136 static_assert(ResultBlockType::kRegisterCount == 2, "");
137 result.buf.reg[0] = LoadInt32x4(src.data(pos));
138 result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
139 return result;
140 }
141};
142
143template <typename SrcScalarType>
144struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>,
145 VectorMap<SrcScalarType, VectorShape::Row>> {
146 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
147 using RegisterBlockType = RegBlockInt32<4, 1>;
148 using ResultBlockType =
149 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
150 SrcObjectType>::Type;
151
152 static ResultBlockType Run(const SrcObjectType& src, int pos) {
153 ResultBlockType result;
154 result.buf.reg[0] = src(pos);
155 return result;
156 }
157};
158
159template <typename SrcScalarType, int N>
160struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>,
161 VectorMap<SrcScalarType, VectorShape::Row>> {
162 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
163 using RegisterBlockType = RegBlockInt32<N, 4>;
164 using ResultBlockType =
165 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
166 SrcObjectType>::Type;
167
168 static ResultBlockType Run(const SrcObjectType& src, int pos) {
169 ResultBlockType result;
170 static_assert(ResultBlockType::kRegisterCount == 1, "");
171 result.buf.reg[0] = LoadInt32x4(src.data(pos));
172 return result;
173 }
174};
175
176template <typename SrcScalarType, int N>
177struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>,
178 VectorMap<SrcScalarType, VectorShape::Row>> {
179 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
180 using RegisterBlockType = RegBlockInt32<N, 8>;
181 using ResultBlockType =
182 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
183 SrcObjectType>::Type;
184
185 static ResultBlockType Run(const SrcObjectType& src, int pos) {
186 ResultBlockType result;
187 static_assert(ResultBlockType::kRegisterCount == 2, "");
188 result.buf.reg[0] = LoadInt32x4(src.data(pos));
189 result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
190 return result;
191 }
192};
193
194// 4x1 := 4x1 + 1x1
195template <>
196struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
197 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
198 const RegBlockInt32<1, 1>& rhs) {
199 RegBlockInt32<4, 1> result;
200 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
201 return result;
202 }
203};
204
205// 1x4 := 1x4 + 1x1
206template <>
207struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
208 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
209 const RegBlockInt32<1, 1>& rhs) {
210 RegBlockInt32<1, 4> result;
211 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
212 return result;
213 }
214};
215
216// 4x1 := 4x1 + 4x1
217template <>
218struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
219 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
220 const RegBlockInt32<4, 1>& rhs) {
221 RegBlockInt32<4, 1> result;
222 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
223 return result;
224 }
225};
226
227// 1x4 := 1x4 + 1x4
228template <>
229struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
230 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
231 const RegBlockInt32<1, 4>& rhs) {
232 RegBlockInt32<1, 4> result;
233 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
234 return result;
235 }
236};
237
238// 4x4 := 4x4 + 1x4
239template <>
240struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
241 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
242 const RegBlockInt32<1, 4>& rhs) {
243 RegBlockInt32<4, 4> result;
244 result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
245 result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
246 result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
247 result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
248 return result;
249 }
250};
251
252// 4x4 := 4x4 + 4x1
253template <>
254struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
255 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
256 const RegBlockInt32<4, 1>& rhs) {
257 RegBlockInt32<4, 4> result;
258 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
259 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]);
260 result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
261 result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]);
262 return result;
263 }
264};
265
266// 8x1 := 8x1 + 1x1
267template <>
268struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
269 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
270 const RegBlockInt32<1, 1>& rhs) {
271 RegBlockInt32<8, 1> result;
272 const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
273 for (int i = 0; i < 2; i++) {
274 result.buf.reg[i] = Add(lhs.buf.reg[i], p);
275 }
276 return result;
277 }
278};
279
280// 8x1 := 8x1 + 8x1
281template <>
282struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
283 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
284 const RegBlockInt32<8, 1>& rhs) {
285 RegBlockInt32<8, 1> result;
286 for (int i = 0; i < 2; i++) {
287 result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
288 }
289 return result;
290 }
291};
292
293// 8x4 := 8x4 + 1x4
294template <>
295struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
296 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
297 const RegBlockInt32<1, 4>& rhs) {
298 RegBlockInt32<8, 4> result;
299 result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
300 result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
301 result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
302 result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
303 result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
304 result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
305 result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
306 result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
307 return result;
308 }
309};
310
311// 8x4 := 8x4 + 8x1
312template <>
313struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
314 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
315 const RegBlockInt32<8, 1>& rhs) {
316 RegBlockInt32<8, 4> result;
317 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
318 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
319 result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
320 result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]);
321 result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]);
322 result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]);
323 result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]);
324 result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]);
325 return result;
326 }
327};
328
329// 1x8 := 1x8 + 1x8
330template <>
331struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
332 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
333 const RegBlockInt32<1, 8>& rhs) {
334 RegBlockInt32<1, 8> result;
335 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
336 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
337 return result;
338 }
339};
340
341// 1x8 := 1x8 + 1x1
342template <>
343struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
344 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
345 const RegBlockInt32<1, 1>& rhs) {
346 RegBlockInt32<1, 8> result;
347 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
348 result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
349 return result;
350 }
351};
352
353// 4x1 := 4x1 + 1x1
354template <>
355struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
356 RegBlockInt32<1, 1>> {
357 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
358 const RegBlockInt32<1, 1>& rhs) {
359 RegBlockInt32<4, 1> result;
360 result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
361 lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
362 return result;
363 }
364};
365
366// 1x4 := 1x4 + 1x1
367template <>
368struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
369 RegBlockInt32<1, 1>> {
370 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
371 const RegBlockInt32<1, 1>& rhs) {
372 RegBlockInt32<1, 4> result;
373 result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
374 lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
375 return result;
376 }
377};
378
379// 4x1 := 4x1 + 4x1
380template <>
381struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
382 RegBlockInt32<4, 1>> {
383 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
384 const RegBlockInt32<4, 1>& rhs) {
385 RegBlockInt32<4, 1> result;
386 result.buf.reg[0] =
387 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
388 return result;
389 }
390};
391
392// 1x4 := 1x4 + 1x4
393template <>
394struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
395 RegBlockInt32<1, 4>> {
396 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
397 const RegBlockInt32<1, 4>& rhs) {
398 RegBlockInt32<1, 4> result;
399 result.buf.reg[0] =
400 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
401 return result;
402 }
403};
404
405// 4x4 := 4x4 + 1x4
406template <>
407struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
408 RegBlockInt32<1, 4>> {
409 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
410 const RegBlockInt32<1, 4>& rhs) {
411 RegBlockInt32<4, 4> result;
412 result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
413 lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
414 result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
415 lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
416 result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
417 lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
418 result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
419 lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
420 return result;
421 }
422};
423
424// 4x4 := 4x4 + 4x1
425template <>
426struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
427 RegBlockInt32<4, 1>> {
428 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
429 const RegBlockInt32<4, 1>& rhs) {
430 RegBlockInt32<4, 4> result;
431 result.buf.reg[0] =
432 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
433 result.buf.reg[1] =
434 SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]);
435 result.buf.reg[2] =
436 SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
437 result.buf.reg[3] =
438 SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]);
439 return result;
440 }
441};
442
443// 8x1 := 8x1 + 1x1
444template <>
445struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
446 RegBlockInt32<1, 1>> {
447 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
448 const RegBlockInt32<1, 1>& rhs) {
449 RegBlockInt32<8, 1> result;
450 const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
451 for (int i = 0; i < 2; i++) {
452 result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p);
453 }
454 return result;
455 }
456};
457
458// 8x1 := 8x1 + 8x1
459template <>
460struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
461 RegBlockInt32<8, 1>> {
462 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
463 const RegBlockInt32<8, 1>& rhs) {
464 RegBlockInt32<8, 1> result;
465 for (int i = 0; i < 2; i++) {
466 result.buf.reg[i] =
467 SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]);
468 }
469 return result;
470 }
471};
472
473// 8x4 := 8x4 + 1x4
474template <>
475struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
476 RegBlockInt32<1, 4>> {
477 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
478 const RegBlockInt32<1, 4>& rhs) {
479 RegBlockInt32<8, 4> result;
480 result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
481 lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
482 result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
483 lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
484 result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
485 lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
486 result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
487 lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
488 result.buf.reg[4] = SaturatingRoundingDoublingHighMul(
489 lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
490 result.buf.reg[5] = SaturatingRoundingDoublingHighMul(
491 lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
492 result.buf.reg[6] = SaturatingRoundingDoublingHighMul(
493 lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
494 result.buf.reg[7] = SaturatingRoundingDoublingHighMul(
495 lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
496 return result;
497 }
498};
499
500// 8x4 := 8x4 + 8x1
501template <>
502struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
503 RegBlockInt32<8, 1>> {
504 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
505 const RegBlockInt32<8, 1>& rhs) {
506 RegBlockInt32<8, 4> result;
507 result.buf.reg[0] =
508 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
509 result.buf.reg[1] =
510 SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
511 result.buf.reg[2] =
512 SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
513 result.buf.reg[3] =
514 SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]);
515 result.buf.reg[4] =
516 SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]);
517 result.buf.reg[5] =
518 SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]);
519 result.buf.reg[6] =
520 SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]);
521 result.buf.reg[7] =
522 SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]);
523 return result;
524 }
525};
526
527// 1x8 := 1x8 + 1x8
528template <>
529struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
530 RegBlockInt32<1, 8>> {
531 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
532 const RegBlockInt32<1, 8>& rhs) {
533 RegBlockInt32<1, 8> result;
534 result.buf.reg[0] =
535 SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
536 result.buf.reg[1] =
537 SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
538 return result;
539 }
540};
541
542// 1x8 := 1x8 + 1x1
543template <>
544struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
545 RegBlockInt32<1, 1>> {
546 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
547 const RegBlockInt32<1, 1>& rhs) {
548 RegBlockInt32<1, 8> result;
549 result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
550 lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
551 result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
552 lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
553 return result;
554 }
555};
556
557// 4x1 := 4x1 * 1x1
558template <>
559struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
560 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
561 const RegBlockInt32<1, 1>& rhs) {
562 RegBlockInt32<4, 1> result;
563 result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
564 return result;
565 }
566};
567
568// 4x1 := 4x1 * 4x1
569template <>
570struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
571 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
572 const RegBlockInt32<4, 1>& rhs) {
573 RegBlockInt32<4, 1> result;
574 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
575 return result;
576 }
577};
578
579// 1x4 := 1x4 * 1x4
580template <>
581struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
582 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
583 const RegBlockInt32<1, 4>& rhs) {
584 RegBlockInt32<1, 4> result;
585 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
586 return result;
587 }
588};
589
590// 1x4 := 1x4 * 1x1
591template <>
592struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
593 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
594 const RegBlockInt32<1, 1>& rhs) {
595 RegBlockInt32<1, 4> result;
596 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
597 return result;
598 }
599};
600
601// 4x4 := 4x4 * 1x4
602template <>
603struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
604 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
605 const RegBlockInt32<1, 4>& rhs) {
606 RegBlockInt32<4, 4> result;
607 const Int32x4 p = rhs.buf.reg[0];
608 result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p);
609 result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p);
610 result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p);
611 result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p);
612 return result;
613 }
614};
615
616// 4x4 := 4x4 * 4x1
617template <>
618struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
619 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
620 const RegBlockInt32<4, 1>& rhs) {
621 RegBlockInt32<4, 4> result;
622 const Int32x4 p = rhs.buf.reg[0];
623 result.buf.reg[0] = Mul(lhs.buf.reg[0], p);
624 result.buf.reg[1] = Mul(lhs.buf.reg[1], p);
625 result.buf.reg[2] = Mul(lhs.buf.reg[2], p);
626 result.buf.reg[3] = Mul(lhs.buf.reg[3], p);
627 return result;
628 }
629};
630
631// 8x1 := 8x1 * 1x1
632template <>
633struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
634 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
635 const RegBlockInt32<1, 1>& rhs) {
636 RegBlockInt32<8, 1> result;
637 const std::int32_t p = rhs.buf.reg[0];
638 for (int i = 0; i < 2; i++) {
639 result.buf.reg[i] = Mul(lhs.buf.reg[i], p);
640 }
641 return result;
642 }
643};
644
645// 8x1 := 8x1 * 8x1
646template <>
647struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
648 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
649 const RegBlockInt32<8, 1>& rhs) {
650 RegBlockInt32<8, 1> result;
651 for (int i = 0; i < 2; i++) {
652 result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]);
653 }
654 return result;
655 }
656};
657
658// 8x4 := 8x4 * 1x4
659template <>
660struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
661 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
662 const RegBlockInt32<1, 4>& rhs) {
663 RegBlockInt32<8, 4> result;
664 const Int32x4 p = rhs.buf.reg[0];
665 for (int i = 0; i < 2; i++) {
666 result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p);
667 result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p);
668 result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p);
669 result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p);
670 }
671 return result;
672 }
673};
674
675// 8x4 := 8x4 * 8x1
676template <>
677struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
678 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
679 const RegBlockInt32<8, 1>& rhs) {
680 RegBlockInt32<8, 4> result;
681 const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]};
682 for (int i = 0; i < 4; i++) {
683 for (int j = 0; j < 2; j++) {
684 const int k = j + 2 * i;
685 result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]);
686 }
687 }
688 return result;
689 }
690};
691
692// Rx1 += Rx1 * 1x1
693template <int Rows>
694struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
695 RegBlockInt32<Rows, 1>> {
696 static void Run(const RegBlockInt32<Rows, 1>& lhs,
697 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) {
698 const std::int32_t p = rhs.buf.reg[0];
699 for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) {
700 MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
701 }
702 }
703};
704
705// RxC += Rx1 * 1x1
706template <int Rows, int Cols>
707struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
708 RegBlockInt32<Rows, Cols>> {
709 static void Run(const RegBlockInt32<Rows, 1>& lhs,
710 const RegBlockInt32<1, 1>& rhs,
711 RegBlockInt32<Rows, Cols>* acc) {
712 const std::int32_t p = rhs.buf.reg[0];
713 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
714 for (int i = 0; i < kRegsPerCol; i++) {
715 const Int32x4 q = Mul(lhs.buf.reg[i], p);
716 for (int j = 0; j < Cols; j++) {
717 acc->buf.reg[i + j * kRegsPerCol] =
718 Add(acc->buf.reg[i + j * kRegsPerCol], q);
719 }
720 }
721 }
722};
723
724// 1xC += 1xC * 1x1
725template <int Cols>
726struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>,
727 RegBlockInt32<1, Cols>> {
728 static void Run(const RegBlockInt32<1, Cols>& lhs,
729 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
730 const std::int32_t p = rhs.buf.reg[0];
731 for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
732 MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
733 }
734 }
735};
736
737// RxC += 1x1 * 1x1
738template <int Rows, int Cols>
739struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
740 RegBlockInt32<Rows, Cols>> {
741 static void Run(const RegBlockInt32<1, 1>& lhs,
742 const RegBlockInt32<1, 1>& rhs,
743 RegBlockInt32<Rows, Cols>* acc) {
744 const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
745 for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) {
746 acc->buf.reg[i] = Add(acc->buf.reg[i], p);
747 }
748 }
749};
750
751// 1x1 += 1x1 * 1x1
752template <>
753struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
754 RegBlockInt32<1, 1>> {
755 static void Run(const RegBlockInt32<1, 1>& lhs,
756 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) {
757 MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]);
758 }
759};
760
761// Rx4 += Rx1 * 1x4
762template <int Rows>
763struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>,
764 RegBlockInt32<Rows, 4>> {
765 static void Run(const RegBlockInt32<Rows, 1>& lhs,
766 const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) {
767 const Int32x4 p = rhs.buf.reg[0];
768 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
769 for (int i = 0; i < kRegsPerCol; i++) {
770 MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]);
771 MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]);
772 MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]);
773 MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]);
774 }
775 }
776};
777
778// Rx4 += 1x4 * 1x1
779template <int Rows>
780struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
781 RegBlockInt32<Rows, 4>> {
782 static void Run(const RegBlockInt32<1, 4>& lhs,
783 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) {
784 const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
785 Int32x4 q[4];
786 q[0] = DupLane<0>(p);
787 q[1] = DupLane<1>(p);
788 q[2] = DupLane<2>(p);
789 q[3] = DupLane<3>(p);
790 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
791 for (int i = 0; i < kRegsPerCol; i++) {
792 for (int j = 0; j < 4; j++) {
793 acc->buf.reg[i + j * kRegsPerCol] =
794 Add(q[j], acc->buf.reg[i + j * kRegsPerCol]);
795 }
796 }
797 }
798};
799
800// 1xC += 1x1 * 1x1
801template <int Cols>
802struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
803 RegBlockInt32<1, Cols>> {
804 static void Run(const RegBlockInt32<1, 1>& lhs,
805 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
806 const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
807 for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
808 acc->buf.reg[i] = Add(acc->buf.reg[i], p);
809 }
810 }
811};
812
813// 1x4 += 1x4 * 1x1
814template <>
815struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
816 RegBlockInt32<1, 4>> {
817 static void Run(const RegBlockInt32<1, 4>& lhs,
818 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) {
819 const std::int32_t p = rhs.buf.reg[0];
820 MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
821 }
822};
823
824// 4xC += 4x1 * 1x1
825template <int Cols>
826struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
827 RegBlockInt32<4, Cols>> {
828 static void Run(const RegBlockInt32<4, 1>& lhs,
829 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) {
830 const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
831 for (int i = 0; i < Cols; i++) {
832 acc->buf.reg[i] = Add(p, acc->buf.reg[i]);
833 }
834 }
835};
836
837// 4x1 += 4x1 * 1x1
838template <>
839struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
840 RegBlockInt32<4, 1>> {
841 static void Run(const RegBlockInt32<4, 1>& lhs,
842 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) {
843 const std::int32_t p = rhs.buf.reg[0];
844 MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
845 }
846};
847
848} // namespace gemmlowp
849
850#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
851