1 | // Copyright 2015 Google Inc. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code |
16 | |
17 | #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ |
18 | #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ |
19 | |
20 | #include "simd_wrappers.h" |
21 | |
22 | namespace gemmlowp { |
23 | |
24 | template <typename SrcScalarType, int N> |
25 | struct LoadImpl<RegBlockInt32<4, N>, |
26 | MatrixMap<SrcScalarType, MapOrder::ColMajor>> { |
27 | static RegBlockInt32<4, N> Run( |
28 | const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, |
29 | int col) { |
30 | RegBlockInt32<4, N> result; |
31 | for (int i = 0; i < N; i++) { |
32 | result.buf.reg[i] = LoadInt32x4(src.data(row, col + i)); |
33 | } |
34 | return result; |
35 | } |
36 | }; |
37 | |
38 | template <typename SrcScalarType, int N> |
39 | struct LoadImpl<RegBlockInt32<8, N>, |
40 | MatrixMap<SrcScalarType, MapOrder::ColMajor>> { |
41 | static RegBlockInt32<8, N> Run( |
42 | const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, |
43 | int col) { |
44 | RegBlockInt32<8, N> result; |
45 | for (int i = 0; i < N; i++) { |
46 | result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i)); |
47 | result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i)); |
48 | } |
49 | return result; |
50 | } |
51 | }; |
52 | |
53 | template <typename SrcScalarType> |
54 | struct LoadImpl<RegBlockInt32<1, 4>, |
55 | MatrixMap<SrcScalarType, MapOrder::ColMajor>> { |
56 | static RegBlockInt32<1, 4> Run( |
57 | const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, |
58 | int col) { |
59 | RegBlockInt32<1, 4> result; |
60 | std::int32_t buf[4]; |
61 | for (int i = 0; i < 4; i++) { |
62 | buf[i] = src(row, col + i); |
63 | } |
64 | result.buf.reg[0] = LoadInt32x4(buf); |
65 | return result; |
66 | } |
67 | }; |
68 | |
69 | template <typename SrcScalarType> |
70 | struct LoadImpl<RegBlockInt32<1, 8>, |
71 | MatrixMap<SrcScalarType, MapOrder::ColMajor>> { |
72 | static RegBlockInt32<1, 8> Run( |
73 | const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, |
74 | int col) { |
75 | RegBlockInt32<1, 8> result; |
76 | std::int32_t buf[8]; |
77 | for (int i = 0; i < 8; i++) { |
78 | buf[i] = src(row, col + i); |
79 | } |
80 | result.buf.reg[0] = LoadInt32x4(buf); |
81 | result.buf.reg[1] = LoadInt32x4(buf + 4); |
82 | return result; |
83 | } |
84 | }; |
85 | |
86 | template <typename SrcScalarType> |
87 | struct LoadImpl<RegBlockInt32<4, 1>, |
88 | VectorMap<SrcScalarType, VectorShape::Col>> { |
89 | static RegBlockInt32<4, 1> Run( |
90 | const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) { |
91 | RegBlockInt32<4, 1> result; |
92 | result.buf.reg[0] = LoadInt32x4(src.data(pos)); |
93 | return result; |
94 | } |
95 | }; |
96 | |
97 | template <typename SrcScalarType> |
98 | struct LoadImpl<RegBlockInt32<4, 1>, |
99 | VectorDup<SrcScalarType, VectorShape::Col>> { |
100 | static RegBlockInt32<4, 1> Run( |
101 | const VectorDup<SrcScalarType, VectorShape::Col>& src, int) { |
102 | RegBlockInt32<4, 1> result; |
103 | result.buf.reg[0] = LoadInt32x4(src(0)); |
104 | return result; |
105 | } |
106 | }; |
107 | |
108 | template <typename SrcScalarType, int N> |
109 | struct LoadForBroadcastingImpl<RegBlockInt32<4, N>, |
110 | VectorMap<SrcScalarType, VectorShape::Col>> { |
111 | using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; |
112 | using RegisterBlockType = RegBlockInt32<4, N>; |
113 | using ResultBlockType = |
114 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
115 | SrcObjectType>::Type; |
116 | |
117 | static ResultBlockType Run(const SrcObjectType& src, int pos) { |
118 | ResultBlockType result; |
119 | static_assert(ResultBlockType::kRegisterCount == 1, "" ); |
120 | result.buf.reg[0] = LoadInt32x4(src.data(pos)); |
121 | return result; |
122 | } |
123 | }; |
124 | |
125 | template <typename SrcScalarType, int N> |
126 | struct LoadForBroadcastingImpl<RegBlockInt32<8, N>, |
127 | VectorMap<SrcScalarType, VectorShape::Col>> { |
128 | using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; |
129 | using RegisterBlockType = RegBlockInt32<8, N>; |
130 | using ResultBlockType = |
131 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
132 | SrcObjectType>::Type; |
133 | |
134 | static ResultBlockType Run(const SrcObjectType& src, int pos) { |
135 | ResultBlockType result; |
136 | static_assert(ResultBlockType::kRegisterCount == 2, "" ); |
137 | result.buf.reg[0] = LoadInt32x4(src.data(pos)); |
138 | result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); |
139 | return result; |
140 | } |
141 | }; |
142 | |
143 | template <typename SrcScalarType> |
144 | struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>, |
145 | VectorMap<SrcScalarType, VectorShape::Row>> { |
146 | using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; |
147 | using RegisterBlockType = RegBlockInt32<4, 1>; |
148 | using ResultBlockType = |
149 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
150 | SrcObjectType>::Type; |
151 | |
152 | static ResultBlockType Run(const SrcObjectType& src, int pos) { |
153 | ResultBlockType result; |
154 | result.buf.reg[0] = src(pos); |
155 | return result; |
156 | } |
157 | }; |
158 | |
159 | template <typename SrcScalarType, int N> |
160 | struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>, |
161 | VectorMap<SrcScalarType, VectorShape::Row>> { |
162 | using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; |
163 | using RegisterBlockType = RegBlockInt32<N, 4>; |
164 | using ResultBlockType = |
165 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
166 | SrcObjectType>::Type; |
167 | |
168 | static ResultBlockType Run(const SrcObjectType& src, int pos) { |
169 | ResultBlockType result; |
170 | static_assert(ResultBlockType::kRegisterCount == 1, "" ); |
171 | result.buf.reg[0] = LoadInt32x4(src.data(pos)); |
172 | return result; |
173 | } |
174 | }; |
175 | |
176 | template <typename SrcScalarType, int N> |
177 | struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>, |
178 | VectorMap<SrcScalarType, VectorShape::Row>> { |
179 | using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; |
180 | using RegisterBlockType = RegBlockInt32<N, 8>; |
181 | using ResultBlockType = |
182 | typename LoadForBroadcastingRegisterBlock<RegisterBlockType, |
183 | SrcObjectType>::Type; |
184 | |
185 | static ResultBlockType Run(const SrcObjectType& src, int pos) { |
186 | ResultBlockType result; |
187 | static_assert(ResultBlockType::kRegisterCount == 2, "" ); |
188 | result.buf.reg[0] = LoadInt32x4(src.data(pos)); |
189 | result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); |
190 | return result; |
191 | } |
192 | }; |
193 | |
194 | // 4x1 := 4x1 + 1x1 |
195 | template <> |
196 | struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { |
197 | static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, |
198 | const RegBlockInt32<1, 1>& rhs) { |
199 | RegBlockInt32<4, 1> result; |
200 | result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); |
201 | return result; |
202 | } |
203 | }; |
204 | |
205 | // 1x4 := 1x4 + 1x1 |
206 | template <> |
207 | struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { |
208 | static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, |
209 | const RegBlockInt32<1, 1>& rhs) { |
210 | RegBlockInt32<1, 4> result; |
211 | result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); |
212 | return result; |
213 | } |
214 | }; |
215 | |
216 | // 4x1 := 4x1 + 4x1 |
217 | template <> |
218 | struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { |
219 | static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, |
220 | const RegBlockInt32<4, 1>& rhs) { |
221 | RegBlockInt32<4, 1> result; |
222 | result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); |
223 | return result; |
224 | } |
225 | }; |
226 | |
227 | // 1x4 := 1x4 + 1x4 |
228 | template <> |
229 | struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { |
230 | static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, |
231 | const RegBlockInt32<1, 4>& rhs) { |
232 | RegBlockInt32<1, 4> result; |
233 | result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); |
234 | return result; |
235 | } |
236 | }; |
237 | |
238 | // 4x4 := 4x4 + 1x4 |
239 | template <> |
240 | struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { |
241 | static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, |
242 | const RegBlockInt32<1, 4>& rhs) { |
243 | RegBlockInt32<4, 4> result; |
244 | result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); |
245 | result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); |
246 | result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); |
247 | result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); |
248 | return result; |
249 | } |
250 | }; |
251 | |
252 | // 4x4 := 4x4 + 4x1 |
253 | template <> |
254 | struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { |
255 | static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, |
256 | const RegBlockInt32<4, 1>& rhs) { |
257 | RegBlockInt32<4, 4> result; |
258 | result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); |
259 | result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]); |
260 | result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); |
261 | result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]); |
262 | return result; |
263 | } |
264 | }; |
265 | |
266 | // 8x1 := 8x1 + 1x1 |
267 | template <> |
268 | struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { |
269 | static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, |
270 | const RegBlockInt32<1, 1>& rhs) { |
271 | RegBlockInt32<8, 1> result; |
272 | const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); |
273 | for (int i = 0; i < 2; i++) { |
274 | result.buf.reg[i] = Add(lhs.buf.reg[i], p); |
275 | } |
276 | return result; |
277 | } |
278 | }; |
279 | |
280 | // 8x1 := 8x1 + 8x1 |
281 | template <> |
282 | struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { |
283 | static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, |
284 | const RegBlockInt32<8, 1>& rhs) { |
285 | RegBlockInt32<8, 1> result; |
286 | for (int i = 0; i < 2; i++) { |
287 | result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); |
288 | } |
289 | return result; |
290 | } |
291 | }; |
292 | |
293 | // 8x4 := 8x4 + 1x4 |
294 | template <> |
295 | struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { |
296 | static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, |
297 | const RegBlockInt32<1, 4>& rhs) { |
298 | RegBlockInt32<8, 4> result; |
299 | result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); |
300 | result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); |
301 | result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); |
302 | result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); |
303 | result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); |
304 | result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); |
305 | result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); |
306 | result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); |
307 | return result; |
308 | } |
309 | }; |
310 | |
311 | // 8x4 := 8x4 + 8x1 |
312 | template <> |
313 | struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { |
314 | static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, |
315 | const RegBlockInt32<8, 1>& rhs) { |
316 | RegBlockInt32<8, 4> result; |
317 | result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); |
318 | result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); |
319 | result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); |
320 | result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]); |
321 | result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]); |
322 | result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]); |
323 | result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]); |
324 | result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]); |
325 | return result; |
326 | } |
327 | }; |
328 | |
329 | // 1x8 := 1x8 + 1x8 |
330 | template <> |
331 | struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> { |
332 | static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, |
333 | const RegBlockInt32<1, 8>& rhs) { |
334 | RegBlockInt32<1, 8> result; |
335 | result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); |
336 | result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); |
337 | return result; |
338 | } |
339 | }; |
340 | |
341 | // 1x8 := 1x8 + 1x1 |
342 | template <> |
343 | struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> { |
344 | static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, |
345 | const RegBlockInt32<1, 1>& rhs) { |
346 | RegBlockInt32<1, 8> result; |
347 | result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); |
348 | result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); |
349 | return result; |
350 | } |
351 | }; |
352 | |
353 | // 4x1 := 4x1 + 1x1 |
354 | template <> |
355 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>, |
356 | RegBlockInt32<1, 1>> { |
357 | static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, |
358 | const RegBlockInt32<1, 1>& rhs) { |
359 | RegBlockInt32<4, 1> result; |
360 | result.buf.reg[0] = SaturatingRoundingDoublingHighMul( |
361 | lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); |
362 | return result; |
363 | } |
364 | }; |
365 | |
366 | // 1x4 := 1x4 + 1x1 |
367 | template <> |
368 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>, |
369 | RegBlockInt32<1, 1>> { |
370 | static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, |
371 | const RegBlockInt32<1, 1>& rhs) { |
372 | RegBlockInt32<1, 4> result; |
373 | result.buf.reg[0] = SaturatingRoundingDoublingHighMul( |
374 | lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); |
375 | return result; |
376 | } |
377 | }; |
378 | |
379 | // 4x1 := 4x1 + 4x1 |
380 | template <> |
381 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>, |
382 | RegBlockInt32<4, 1>> { |
383 | static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, |
384 | const RegBlockInt32<4, 1>& rhs) { |
385 | RegBlockInt32<4, 1> result; |
386 | result.buf.reg[0] = |
387 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); |
388 | return result; |
389 | } |
390 | }; |
391 | |
392 | // 1x4 := 1x4 + 1x4 |
393 | template <> |
394 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>, |
395 | RegBlockInt32<1, 4>> { |
396 | static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, |
397 | const RegBlockInt32<1, 4>& rhs) { |
398 | RegBlockInt32<1, 4> result; |
399 | result.buf.reg[0] = |
400 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); |
401 | return result; |
402 | } |
403 | }; |
404 | |
405 | // 4x4 := 4x4 + 1x4 |
406 | template <> |
407 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>, |
408 | RegBlockInt32<1, 4>> { |
409 | static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, |
410 | const RegBlockInt32<1, 4>& rhs) { |
411 | RegBlockInt32<4, 4> result; |
412 | result.buf.reg[0] = SaturatingRoundingDoublingHighMul( |
413 | lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); |
414 | result.buf.reg[1] = SaturatingRoundingDoublingHighMul( |
415 | lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); |
416 | result.buf.reg[2] = SaturatingRoundingDoublingHighMul( |
417 | lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); |
418 | result.buf.reg[3] = SaturatingRoundingDoublingHighMul( |
419 | lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); |
420 | return result; |
421 | } |
422 | }; |
423 | |
424 | // 4x4 := 4x4 + 4x1 |
425 | template <> |
426 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>, |
427 | RegBlockInt32<4, 1>> { |
428 | static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, |
429 | const RegBlockInt32<4, 1>& rhs) { |
430 | RegBlockInt32<4, 4> result; |
431 | result.buf.reg[0] = |
432 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); |
433 | result.buf.reg[1] = |
434 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]); |
435 | result.buf.reg[2] = |
436 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); |
437 | result.buf.reg[3] = |
438 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]); |
439 | return result; |
440 | } |
441 | }; |
442 | |
443 | // 8x1 := 8x1 + 1x1 |
444 | template <> |
445 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>, |
446 | RegBlockInt32<1, 1>> { |
447 | static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, |
448 | const RegBlockInt32<1, 1>& rhs) { |
449 | RegBlockInt32<8, 1> result; |
450 | const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); |
451 | for (int i = 0; i < 2; i++) { |
452 | result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p); |
453 | } |
454 | return result; |
455 | } |
456 | }; |
457 | |
458 | // 8x1 := 8x1 + 8x1 |
459 | template <> |
460 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>, |
461 | RegBlockInt32<8, 1>> { |
462 | static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, |
463 | const RegBlockInt32<8, 1>& rhs) { |
464 | RegBlockInt32<8, 1> result; |
465 | for (int i = 0; i < 2; i++) { |
466 | result.buf.reg[i] = |
467 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]); |
468 | } |
469 | return result; |
470 | } |
471 | }; |
472 | |
473 | // 8x4 := 8x4 + 1x4 |
474 | template <> |
475 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>, |
476 | RegBlockInt32<1, 4>> { |
477 | static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, |
478 | const RegBlockInt32<1, 4>& rhs) { |
479 | RegBlockInt32<8, 4> result; |
480 | result.buf.reg[0] = SaturatingRoundingDoublingHighMul( |
481 | lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); |
482 | result.buf.reg[1] = SaturatingRoundingDoublingHighMul( |
483 | lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); |
484 | result.buf.reg[2] = SaturatingRoundingDoublingHighMul( |
485 | lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); |
486 | result.buf.reg[3] = SaturatingRoundingDoublingHighMul( |
487 | lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); |
488 | result.buf.reg[4] = SaturatingRoundingDoublingHighMul( |
489 | lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); |
490 | result.buf.reg[5] = SaturatingRoundingDoublingHighMul( |
491 | lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); |
492 | result.buf.reg[6] = SaturatingRoundingDoublingHighMul( |
493 | lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); |
494 | result.buf.reg[7] = SaturatingRoundingDoublingHighMul( |
495 | lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); |
496 | return result; |
497 | } |
498 | }; |
499 | |
500 | // 8x4 := 8x4 + 8x1 |
501 | template <> |
502 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>, |
503 | RegBlockInt32<8, 1>> { |
504 | static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, |
505 | const RegBlockInt32<8, 1>& rhs) { |
506 | RegBlockInt32<8, 4> result; |
507 | result.buf.reg[0] = |
508 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); |
509 | result.buf.reg[1] = |
510 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); |
511 | result.buf.reg[2] = |
512 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); |
513 | result.buf.reg[3] = |
514 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]); |
515 | result.buf.reg[4] = |
516 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]); |
517 | result.buf.reg[5] = |
518 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]); |
519 | result.buf.reg[6] = |
520 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]); |
521 | result.buf.reg[7] = |
522 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]); |
523 | return result; |
524 | } |
525 | }; |
526 | |
527 | // 1x8 := 1x8 + 1x8 |
528 | template <> |
529 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>, |
530 | RegBlockInt32<1, 8>> { |
531 | static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, |
532 | const RegBlockInt32<1, 8>& rhs) { |
533 | RegBlockInt32<1, 8> result; |
534 | result.buf.reg[0] = |
535 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); |
536 | result.buf.reg[1] = |
537 | SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); |
538 | return result; |
539 | } |
540 | }; |
541 | |
542 | // 1x8 := 1x8 + 1x1 |
543 | template <> |
544 | struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>, |
545 | RegBlockInt32<1, 1>> { |
546 | static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, |
547 | const RegBlockInt32<1, 1>& rhs) { |
548 | RegBlockInt32<1, 8> result; |
549 | result.buf.reg[0] = SaturatingRoundingDoublingHighMul( |
550 | lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); |
551 | result.buf.reg[1] = SaturatingRoundingDoublingHighMul( |
552 | lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); |
553 | return result; |
554 | } |
555 | }; |
556 | |
557 | // 4x1 := 4x1 * 1x1 |
558 | template <> |
559 | struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { |
560 | static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, |
561 | const RegBlockInt32<1, 1>& rhs) { |
562 | RegBlockInt32<4, 1> result; |
563 | result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); |
564 | return result; |
565 | } |
566 | }; |
567 | |
568 | // 4x1 := 4x1 * 4x1 |
569 | template <> |
570 | struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { |
571 | static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, |
572 | const RegBlockInt32<4, 1>& rhs) { |
573 | RegBlockInt32<4, 1> result; |
574 | result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); |
575 | return result; |
576 | } |
577 | }; |
578 | |
579 | // 1x4 := 1x4 * 1x4 |
580 | template <> |
581 | struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { |
582 | static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, |
583 | const RegBlockInt32<1, 4>& rhs) { |
584 | RegBlockInt32<1, 4> result; |
585 | result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); |
586 | return result; |
587 | } |
588 | }; |
589 | |
590 | // 1x4 := 1x4 * 1x1 |
591 | template <> |
592 | struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { |
593 | static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, |
594 | const RegBlockInt32<1, 1>& rhs) { |
595 | RegBlockInt32<1, 4> result; |
596 | result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); |
597 | return result; |
598 | } |
599 | }; |
600 | |
601 | // 4x4 := 4x4 * 1x4 |
602 | template <> |
603 | struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { |
604 | static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, |
605 | const RegBlockInt32<1, 4>& rhs) { |
606 | RegBlockInt32<4, 4> result; |
607 | const Int32x4 p = rhs.buf.reg[0]; |
608 | result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p); |
609 | result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p); |
610 | result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p); |
611 | result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p); |
612 | return result; |
613 | } |
614 | }; |
615 | |
616 | // 4x4 := 4x4 * 4x1 |
617 | template <> |
618 | struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { |
619 | static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, |
620 | const RegBlockInt32<4, 1>& rhs) { |
621 | RegBlockInt32<4, 4> result; |
622 | const Int32x4 p = rhs.buf.reg[0]; |
623 | result.buf.reg[0] = Mul(lhs.buf.reg[0], p); |
624 | result.buf.reg[1] = Mul(lhs.buf.reg[1], p); |
625 | result.buf.reg[2] = Mul(lhs.buf.reg[2], p); |
626 | result.buf.reg[3] = Mul(lhs.buf.reg[3], p); |
627 | return result; |
628 | } |
629 | }; |
630 | |
631 | // 8x1 := 8x1 * 1x1 |
632 | template <> |
633 | struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { |
634 | static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, |
635 | const RegBlockInt32<1, 1>& rhs) { |
636 | RegBlockInt32<8, 1> result; |
637 | const std::int32_t p = rhs.buf.reg[0]; |
638 | for (int i = 0; i < 2; i++) { |
639 | result.buf.reg[i] = Mul(lhs.buf.reg[i], p); |
640 | } |
641 | return result; |
642 | } |
643 | }; |
644 | |
645 | // 8x1 := 8x1 * 8x1 |
646 | template <> |
647 | struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { |
648 | static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, |
649 | const RegBlockInt32<8, 1>& rhs) { |
650 | RegBlockInt32<8, 1> result; |
651 | for (int i = 0; i < 2; i++) { |
652 | result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]); |
653 | } |
654 | return result; |
655 | } |
656 | }; |
657 | |
658 | // 8x4 := 8x4 * 1x4 |
659 | template <> |
660 | struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { |
661 | static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, |
662 | const RegBlockInt32<1, 4>& rhs) { |
663 | RegBlockInt32<8, 4> result; |
664 | const Int32x4 p = rhs.buf.reg[0]; |
665 | for (int i = 0; i < 2; i++) { |
666 | result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p); |
667 | result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p); |
668 | result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p); |
669 | result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p); |
670 | } |
671 | return result; |
672 | } |
673 | }; |
674 | |
675 | // 8x4 := 8x4 * 8x1 |
676 | template <> |
677 | struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { |
678 | static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, |
679 | const RegBlockInt32<8, 1>& rhs) { |
680 | RegBlockInt32<8, 4> result; |
681 | const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]}; |
682 | for (int i = 0; i < 4; i++) { |
683 | for (int j = 0; j < 2; j++) { |
684 | const int k = j + 2 * i; |
685 | result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]); |
686 | } |
687 | } |
688 | return result; |
689 | } |
690 | }; |
691 | |
692 | // Rx1 += Rx1 * 1x1 |
693 | template <int Rows> |
694 | struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, |
695 | RegBlockInt32<Rows, 1>> { |
696 | static void Run(const RegBlockInt32<Rows, 1>& lhs, |
697 | const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) { |
698 | const std::int32_t p = rhs.buf.reg[0]; |
699 | for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) { |
700 | MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); |
701 | } |
702 | } |
703 | }; |
704 | |
705 | // RxC += Rx1 * 1x1 |
706 | template <int Rows, int Cols> |
707 | struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, |
708 | RegBlockInt32<Rows, Cols>> { |
709 | static void Run(const RegBlockInt32<Rows, 1>& lhs, |
710 | const RegBlockInt32<1, 1>& rhs, |
711 | RegBlockInt32<Rows, Cols>* acc) { |
712 | const std::int32_t p = rhs.buf.reg[0]; |
713 | static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; |
714 | for (int i = 0; i < kRegsPerCol; i++) { |
715 | const Int32x4 q = Mul(lhs.buf.reg[i], p); |
716 | for (int j = 0; j < Cols; j++) { |
717 | acc->buf.reg[i + j * kRegsPerCol] = |
718 | Add(acc->buf.reg[i + j * kRegsPerCol], q); |
719 | } |
720 | } |
721 | } |
722 | }; |
723 | |
724 | // 1xC += 1xC * 1x1 |
725 | template <int Cols> |
726 | struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>, |
727 | RegBlockInt32<1, Cols>> { |
728 | static void Run(const RegBlockInt32<1, Cols>& lhs, |
729 | const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { |
730 | const std::int32_t p = rhs.buf.reg[0]; |
731 | for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { |
732 | MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); |
733 | } |
734 | } |
735 | }; |
736 | |
737 | // RxC += 1x1 * 1x1 |
738 | template <int Rows, int Cols> |
739 | struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, |
740 | RegBlockInt32<Rows, Cols>> { |
741 | static void Run(const RegBlockInt32<1, 1>& lhs, |
742 | const RegBlockInt32<1, 1>& rhs, |
743 | RegBlockInt32<Rows, Cols>* acc) { |
744 | const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); |
745 | for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) { |
746 | acc->buf.reg[i] = Add(acc->buf.reg[i], p); |
747 | } |
748 | } |
749 | }; |
750 | |
751 | // 1x1 += 1x1 * 1x1 |
752 | template <> |
753 | struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, |
754 | RegBlockInt32<1, 1>> { |
755 | static void Run(const RegBlockInt32<1, 1>& lhs, |
756 | const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) { |
757 | MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]); |
758 | } |
759 | }; |
760 | |
761 | // Rx4 += Rx1 * 1x4 |
762 | template <int Rows> |
763 | struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>, |
764 | RegBlockInt32<Rows, 4>> { |
765 | static void Run(const RegBlockInt32<Rows, 1>& lhs, |
766 | const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) { |
767 | const Int32x4 p = rhs.buf.reg[0]; |
768 | static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; |
769 | for (int i = 0; i < kRegsPerCol; i++) { |
770 | MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]); |
771 | MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]); |
772 | MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]); |
773 | MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]); |
774 | } |
775 | } |
776 | }; |
777 | |
778 | // Rx4 += 1x4 * 1x1 |
779 | template <int Rows> |
780 | struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, |
781 | RegBlockInt32<Rows, 4>> { |
782 | static void Run(const RegBlockInt32<1, 4>& lhs, |
783 | const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) { |
784 | const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); |
785 | Int32x4 q[4]; |
786 | q[0] = DupLane<0>(p); |
787 | q[1] = DupLane<1>(p); |
788 | q[2] = DupLane<2>(p); |
789 | q[3] = DupLane<3>(p); |
790 | static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; |
791 | for (int i = 0; i < kRegsPerCol; i++) { |
792 | for (int j = 0; j < 4; j++) { |
793 | acc->buf.reg[i + j * kRegsPerCol] = |
794 | Add(q[j], acc->buf.reg[i + j * kRegsPerCol]); |
795 | } |
796 | } |
797 | } |
798 | }; |
799 | |
800 | // 1xC += 1x1 * 1x1 |
801 | template <int Cols> |
802 | struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, |
803 | RegBlockInt32<1, Cols>> { |
804 | static void Run(const RegBlockInt32<1, 1>& lhs, |
805 | const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { |
806 | const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); |
807 | for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { |
808 | acc->buf.reg[i] = Add(acc->buf.reg[i], p); |
809 | } |
810 | } |
811 | }; |
812 | |
813 | // 1x4 += 1x4 * 1x1 |
814 | template <> |
815 | struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, |
816 | RegBlockInt32<1, 4>> { |
817 | static void Run(const RegBlockInt32<1, 4>& lhs, |
818 | const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) { |
819 | const std::int32_t p = rhs.buf.reg[0]; |
820 | MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); |
821 | } |
822 | }; |
823 | |
824 | // 4xC += 4x1 * 1x1 |
825 | template <int Cols> |
826 | struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, |
827 | RegBlockInt32<4, Cols>> { |
828 | static void Run(const RegBlockInt32<4, 1>& lhs, |
829 | const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) { |
830 | const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); |
831 | for (int i = 0; i < Cols; i++) { |
832 | acc->buf.reg[i] = Add(p, acc->buf.reg[i]); |
833 | } |
834 | } |
835 | }; |
836 | |
837 | // 4x1 += 4x1 * 1x1 |
838 | template <> |
839 | struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, |
840 | RegBlockInt32<4, 1>> { |
841 | static void Run(const RegBlockInt32<4, 1>& lhs, |
842 | const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) { |
843 | const std::int32_t p = rhs.buf.reg[0]; |
844 | MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); |
845 | } |
846 | }; |
847 | |
848 | } // namespace gemmlowp |
849 | |
850 | #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ |
851 | |