1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// output_sse.h: optimized SSE4.2 specializations of the templates in output.h.
16
17#ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
18#define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
19
20#include "output.h"
21
22#include <smmintrin.h>
23
24namespace gemmlowp {
25
26template <>
27struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
28 RegBufferInt32<4>> {
29 typedef RegBufferInt32<4> InputType;
30 typedef RegBufferUint8<4> OutputType;
31
32 typedef OutputStageSaturatingCastToUint8 OutputStage;
33
34 OutputStageEvalBufferImpl(const OutputStage&) {}
35
36 OutputType Eval(InputType input) const {
37 OutputType output;
38 __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
39 __m128i res_8 = _mm_packus_epi16(res_16, res_16);
40 output.reg[0] = _mm_cvtsi128_si32(res_8);
41 return output;
42 }
43};
44
45template <>
46struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
47 RegBufferInt32<8>> {
48 typedef RegBufferInt32<8> InputType;
49 typedef RegBufferUint8<8> OutputType;
50
51 typedef OutputStageSaturatingCastToUint8 OutputStage;
52
53 OutputStageEvalBufferImpl(const OutputStage&) {}
54
55 OutputType Eval(InputType input) const {
56 OutputType output;
57 __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]);
58 __m128i res_8 = _mm_packus_epi16(res_16, res_16);
59 output.reg[0] = _mm_extract_epi32(res_8, 0);
60 output.reg[1] = _mm_extract_epi32(res_8, 1);
61 return output;
62 }
63};
64
65template <>
66struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
67 RegBufferInt32<16>> {
68 typedef RegBufferInt32<16> InputType;
69 typedef RegBufferUint8<16> OutputType;
70
71 typedef OutputStageSaturatingCastToUint8 OutputStage;
72
73 OutputStageEvalBufferImpl(const OutputStage&) {}
74
75 OutputType Eval(InputType input) const {
76 OutputType output;
77 __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
78 __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
79 output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
80 return output;
81 }
82};
83
84template <>
85struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
86 RegBufferInt32<32>> {
87 typedef RegBufferInt32<32> InputType;
88 typedef RegBufferUint8<32> OutputType;
89
90 typedef OutputStageSaturatingCastToUint8 OutputStage;
91
92 OutputStageEvalBufferImpl(const OutputStage&) {}
93
94 OutputType Eval(InputType input) const {
95 OutputType output;
96 __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
97 __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
98 output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
99 __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]);
100 __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]);
101 output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3);
102 return output;
103 }
104};
105
106template <>
107struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
108 RegBufferInt32<4>> {
109 typedef RegBufferInt32<4> InputType;
110 typedef RegBufferInt16<4> OutputType;
111
112 typedef OutputStageSaturatingCastToInt16 OutputStage;
113
114 OutputStageEvalBufferImpl(const OutputStage&) {}
115
116 OutputType Eval(InputType input) const {
117 OutputType output;
118 __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
119 output.reg[0] = _mm_extract_epi16(res_16, 0);
120 output.reg[1] = _mm_extract_epi16(res_16, 1);
121 output.reg[2] = _mm_extract_epi16(res_16, 2);
122 output.reg[3] = _mm_extract_epi16(res_16, 3);
123 return output;
124 }
125};
126
127template <>
128struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
129 RegBufferInt32<8>> {
130 typedef RegBufferInt32<8> InputType;
131 typedef RegBufferInt16<8> OutputType;
132
133 typedef OutputStageSaturatingCastToInt16 OutputStage;
134
135 OutputStageEvalBufferImpl(const OutputStage&) {}
136
137 OutputType Eval(InputType input) const {
138 OutputType output;
139 output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
140 return output;
141 }
142};
143
144template <>
145struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
146 RegBufferInt32<16>> {
147 typedef RegBufferInt32<16> InputType;
148 typedef RegBufferInt16<16> OutputType;
149
150 typedef OutputStageSaturatingCastToInt16 OutputStage;
151
152 OutputStageEvalBufferImpl(const OutputStage&) {}
153
154 OutputType Eval(InputType input) const {
155 OutputType output;
156 output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
157 output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]);
158 return output;
159 }
160};
161
162template <>
163struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
164 RegBufferInt32<32>> {
165 typedef RegBufferInt32<32> InputType;
166 typedef RegBufferInt16<32> OutputType;
167
168 typedef OutputStageSaturatingCastToInt16 OutputStage;
169
170 OutputStageEvalBufferImpl(const OutputStage&) {}
171
172 OutputType Eval(InputType input) const {
173 OutputType output;
174 output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
175 output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]);
176 output.reg[2] = _mm_packs_epi32(input.reg[4], input.reg[5]);
177 output.reg[3] = _mm_packs_epi32(input.reg[6], input.reg[7]);
178 return output;
179 }
180};
181
182template <typename DstType>
183struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
184 static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
185 int col) {
186 if (DstType::kOrder == MapOrder::ColMajor) {
187 StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
188 } else {
189 *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
190 *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
191 *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
192 *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
193 }
194 }
195};
196
197template <typename DstType>
198struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
199 static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
200 int col) {
201 if (DstType::kOrder == MapOrder::ColMajor) {
202 StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
203 StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
204 } else {
205 *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
206 *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
207 *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
208 *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
209 *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
210 *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
211 *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
212 *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
213 }
214 }
215};
216
217template <typename DstType>
218struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
219 static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
220 int col) {
221 *dst->data(row + 0, col) = src.buf.reg[0];
222 *dst->data(row + 1, col) = src.buf.reg[1];
223 *dst->data(row + 2, col) = src.buf.reg[2];
224 *dst->data(row + 3, col) = src.buf.reg[3];
225 }
226};
227
228template <typename DstType>
229struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
230 static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
231 int col) {
232 if (DstType::kOrder == MapOrder::ColMajor) {
233 StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
234 } else {
235 *dst->data(row + 0, col) = _mm_extract_epi16(src.buf.reg[0], 0);
236 *dst->data(row + 1, col) = _mm_extract_epi16(src.buf.reg[0], 1);
237 *dst->data(row + 2, col) = _mm_extract_epi16(src.buf.reg[0], 2);
238 *dst->data(row + 3, col) = _mm_extract_epi16(src.buf.reg[0], 3);
239 *dst->data(row + 4, col) = _mm_extract_epi16(src.buf.reg[0], 4);
240 *dst->data(row + 5, col) = _mm_extract_epi16(src.buf.reg[0], 5);
241 *dst->data(row + 6, col) = _mm_extract_epi16(src.buf.reg[0], 6);
242 *dst->data(row + 7, col) = _mm_extract_epi16(src.buf.reg[0], 7);
243 }
244 }
245};
246
247inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
248 __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]);
249 __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]);
250 __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]);
251 __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]);
252
253 RegBlockInt32<4, 4> result;
254 result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1);
255 result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1);
256 result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3);
257 result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3);
258 return result;
259}
260
261template <typename DstType>
262struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
263 static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
264 int col) {
265 if (DstType::kOrder == MapOrder::ColMajor) {
266 for (int i = 0; i < 4; i++) {
267 StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
268 }
269 } else {
270 const auto transpose = Transpose(src);
271 for (int i = 0; i < 4; i++) {
272 StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
273 }
274 }
275 }
276};
277
278template <typename DstType>
279struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
280 static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
281 int col) {
282 std::int16_t buf[16];
283 StoreInt16x8(buf + 0, src.buf.reg[0]);
284 StoreInt16x8(buf + 8, src.buf.reg[1]);
285 for (int i = 0; i < 4; i++) {
286 for (int j = 0; j < 4; j++) {
287 *dst->data(row + i, col + j) = buf[i + 4 * j];
288 }
289 }
290 }
291};
292
293template <typename DstType>
294struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
295 static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
296 int col) {
297 if (DstType::kOrder == MapOrder::ColMajor) {
298 for (int i = 0; i < 4; i++) {
299 StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
300 StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
301 }
302 } else {
303 RegBlockInt32<4, 4> top;
304 top.buf.reg[0] = src.buf.reg[0];
305 top.buf.reg[1] = src.buf.reg[2];
306 top.buf.reg[2] = src.buf.reg[4];
307 top.buf.reg[3] = src.buf.reg[6];
308 const auto transpose_top = Transpose(top);
309 for (int i = 0; i < 4; i++) {
310 StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
311 }
312 RegBlockInt32<4, 4> bottom;
313 bottom.buf.reg[0] = src.buf.reg[1];
314 bottom.buf.reg[1] = src.buf.reg[3];
315 bottom.buf.reg[2] = src.buf.reg[5];
316 bottom.buf.reg[3] = src.buf.reg[7];
317 const auto transpose_bottom = Transpose(bottom);
318 for (int i = 0; i < 4; i++) {
319 StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
320 }
321 }
322 }
323};
324
325template <typename DstType>
326struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
327 static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
328 int col) {
329 if (DstType::kOrder == MapOrder::ColMajor) {
330 for (int i = 0; i < 4; i++) {
331 StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
332 }
333 } else {
334 std::int16_t buf[32];
335 StoreInt16x8(buf + 0, src.buf.reg[0]);
336 StoreInt16x8(buf + 8, src.buf.reg[1]);
337 StoreInt16x8(buf + 16, src.buf.reg[2]);
338 StoreInt16x8(buf + 24, src.buf.reg[3]);
339 for (int i = 0; i < 8; i++) {
340 for (int j = 0; j < 4; j++) {
341 *dst->data(row + i, col + j) = buf[i + 8 * j];
342 }
343 }
344 }
345 }
346};
347
348template <typename DstType>
349struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
350 static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
351 int col) {
352 if (DstType::kOrder == MapOrder::ColMajor) {
353 for (int i = 0; i < 8; i++) {
354 StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
355 StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
356 }
357 } else {
358 RegBlockInt32<4, 4> top_left;
359 top_left.buf.reg[0] = src.buf.reg[0];
360 top_left.buf.reg[1] = src.buf.reg[2];
361 top_left.buf.reg[2] = src.buf.reg[4];
362 top_left.buf.reg[3] = src.buf.reg[6];
363 const auto transpose_top_left = Transpose(top_left);
364 for (int i = 0; i < 4; i++) {
365 StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
366 }
367 RegBlockInt32<4, 4> bottom_left;
368 bottom_left.buf.reg[0] = src.buf.reg[1];
369 bottom_left.buf.reg[1] = src.buf.reg[3];
370 bottom_left.buf.reg[2] = src.buf.reg[5];
371 bottom_left.buf.reg[3] = src.buf.reg[7];
372 const auto transpose_bottom_left = Transpose(bottom_left);
373 for (int i = 0; i < 4; i++) {
374 StoreInt32x4(dst->data(row + 4 + i, col),
375 transpose_bottom_left.buf.reg[i]);
376 }
377 RegBlockInt32<4, 4> top_right;
378 top_right.buf.reg[0] = src.buf.reg[8];
379 top_right.buf.reg[1] = src.buf.reg[10];
380 top_right.buf.reg[2] = src.buf.reg[12];
381 top_right.buf.reg[3] = src.buf.reg[14];
382 const auto transpose_top_right = Transpose(top_right);
383 for (int i = 0; i < 4; i++) {
384 StoreInt32x4(dst->data(row + i, col + 4),
385 transpose_top_right.buf.reg[i]);
386 }
387 RegBlockInt32<4, 4> bottom_right;
388 bottom_right.buf.reg[0] = src.buf.reg[9];
389 bottom_right.buf.reg[1] = src.buf.reg[11];
390 bottom_right.buf.reg[2] = src.buf.reg[13];
391 bottom_right.buf.reg[3] = src.buf.reg[15];
392 const auto transpose_bottom_right = Transpose(bottom_right);
393 for (int i = 0; i < 4; i++) {
394 StoreInt32x4(dst->data(row + 4 + i, col + 4),
395 transpose_bottom_right.buf.reg[i]);
396 }
397 }
398 }
399};
400
401template <typename DstType>
402struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
403 static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
404 int col) {
405 if (DstType::kOrder == MapOrder::ColMajor) {
406 for (int i = 0; i < 8; i++) {
407 StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
408 }
409 } else {
410 // top-left 4x4
411 __m128i t0 = _mm_unpacklo_epi16(src.buf.reg[0], src.buf.reg[1]);
412 __m128i t1 = _mm_unpacklo_epi16(src.buf.reg[2], src.buf.reg[3]);
413 __m128i u0 = _mm_unpacklo_epi32(t0, t1);
414 __m128i u1 = _mm_unpackhi_epi32(t0, t1);
415 // top-right 4x4
416 __m128i t2 = _mm_unpacklo_epi16(src.buf.reg[4], src.buf.reg[5]);
417 __m128i t3 = _mm_unpacklo_epi16(src.buf.reg[6], src.buf.reg[7]);
418 __m128i u2 = _mm_unpacklo_epi32(t2, t3);
419 __m128i u3 = _mm_unpackhi_epi32(t2, t3);
420 // bottom-left 4x4
421 __m128i t4 = _mm_unpackhi_epi16(src.buf.reg[0], src.buf.reg[1]);
422 __m128i t5 = _mm_unpackhi_epi16(src.buf.reg[2], src.buf.reg[3]);
423 __m128i u4 = _mm_unpacklo_epi32(t4, t5);
424 __m128i u5 = _mm_unpackhi_epi32(t4, t5);
425 // bottom-right 4x4
426 __m128i t6 = _mm_unpackhi_epi16(src.buf.reg[4], src.buf.reg[5]);
427 __m128i t7 = _mm_unpackhi_epi16(src.buf.reg[6], src.buf.reg[7]);
428 __m128i u6 = _mm_unpacklo_epi32(t6, t7);
429 __m128i u7 = _mm_unpackhi_epi32(t6, t7);
430
431 StoreInt16x8(dst->data(row + 0, col), _mm_unpacklo_epi64(u0, u2));
432 StoreInt16x8(dst->data(row + 1, col), _mm_unpackhi_epi64(u0, u2));
433 StoreInt16x8(dst->data(row + 2, col), _mm_unpacklo_epi64(u1, u3));
434 StoreInt16x8(dst->data(row + 3, col), _mm_unpackhi_epi64(u1, u3));
435 StoreInt16x8(dst->data(row + 4, col), _mm_unpacklo_epi64(u4, u6));
436 StoreInt16x8(dst->data(row + 5, col), _mm_unpackhi_epi64(u4, u6));
437 StoreInt16x8(dst->data(row + 6, col), _mm_unpacklo_epi64(u5, u7));
438 StoreInt16x8(dst->data(row + 7, col), _mm_unpackhi_epi64(u5, u7));
439 }
440 }
441};
442
443template <typename DstType>
444struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
445 static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
446 int col) {
447 if (DstType::kOrder == MapOrder::ColMajor) {
448 *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
449 *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
450 *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
451 *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
452 } else {
453 StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
454 }
455 }
456};
457
458template <typename DstType>
459struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
460 static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
461 int col) {
462 const std::uint32_t src_reg = src.buf.reg[0];
463 for (int i = 0; i < 4; i++) {
464 *dst->data(row + i, col) = (src_reg >> (8 * i));
465 }
466 }
467};
468
469template <typename DstType>
470struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
471 static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
472 int col) {
473 for (int i = 0; i < 4; i++) {
474 *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
475 }
476 for (int i = 0; i < 4; i++) {
477 *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
478 }
479 }
480};
481
482template <typename DstType>
483struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
484 static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
485 int col) {
486 for (int i = 0; i < 4; i++) {
487 *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
488 }
489 }
490};
491
492template <typename DstType>
493struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
494 static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
495 int col) {
496 std::uint8_t buf[16];
497 StoreUint8x16(buf, src.buf.reg[0]);
498 for (int c = 0; c < 4; c++) {
499 for (int r = 0; r < 4; r++) {
500 *dst->data(row + r, col + c) = buf[r + 4 * c];
501 }
502 }
503 }
504};
505
506template <typename DstType>
507struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
508 static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
509 int col) {
510 std::uint8_t buf[32];
511 StoreUint8x16(buf, src.buf.reg[0]);
512 StoreUint8x16(buf + 16, src.buf.reg[1]);
513 for (int c = 0; c < 4; c++) {
514 for (int r = 0; r < 8; r++) {
515 *dst->data(row + r, col + c) = buf[r + 8 * c];
516 }
517 }
518 }
519};
520
521template <typename DstType>
522struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
523 static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
524 int col) {
525 std::uint8_t buf[64];
526 StoreUint8x16(buf, src.buf.reg[0]);
527 StoreUint8x16(buf + 16, src.buf.reg[1]);
528 StoreUint8x16(buf + 32, src.buf.reg[2]);
529 StoreUint8x16(buf + 48, src.buf.reg[3]);
530 for (int c = 0; c < 8; c++) {
531 for (int r = 0; r < 8; r++) {
532 *dst->data(row + r, col + c) = buf[r + 8 * c];
533 }
534 }
535 }
536};
537
538// Specialization for MatrixMap, for performance.
539template <typename tScalar, MapOrder tOrder>
540struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, MatrixMap<tScalar, tOrder>> {
541 static void Run(const RegBlockUint8<8, 8>& src,
542 MatrixMap<tScalar, tOrder>* dst, int row, int col) {
543 std::uint8_t buf[64];
544 StoreUint8x16(buf, src.buf.reg[0]);
545 StoreUint8x16(buf + 16, src.buf.reg[1]);
546 StoreUint8x16(buf + 32, src.buf.reg[2]);
547 StoreUint8x16(buf + 48, src.buf.reg[3]);
548 // Make a local copy so that the compiler can prove that data_ does not
549 // alias &data_ or &stride_.
550 MatrixMap<tScalar, tOrder> local = *dst;
551 for (int c = 0; c < 8; c++) {
552 for (int r = 0; r < 8; r++) {
553 *local.data(row + r, col + c) = buf[r + 8 * c];
554 }
555 }
556 }
557};
558
559} // namespace gemmlowp
560
561#endif // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
562