1 | // Copyright 2015 Google Inc. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // output_sse.h: optimized SSE4.2 specializations of the templates in output.h. |
16 | |
17 | #ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ |
18 | #define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ |
19 | |
20 | #include "output.h" |
21 | |
22 | #include <smmintrin.h> |
23 | |
24 | namespace gemmlowp { |
25 | |
26 | template <> |
27 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, |
28 | RegBufferInt32<4>> { |
29 | typedef RegBufferInt32<4> InputType; |
30 | typedef RegBufferUint8<4> OutputType; |
31 | |
32 | typedef OutputStageSaturatingCastToUint8 OutputStage; |
33 | |
34 | OutputStageEvalBufferImpl(const OutputStage&) {} |
35 | |
36 | OutputType Eval(InputType input) const { |
37 | OutputType output; |
38 | __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]); |
39 | __m128i res_8 = _mm_packus_epi16(res_16, res_16); |
40 | output.reg[0] = _mm_cvtsi128_si32(res_8); |
41 | return output; |
42 | } |
43 | }; |
44 | |
45 | template <> |
46 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, |
47 | RegBufferInt32<8>> { |
48 | typedef RegBufferInt32<8> InputType; |
49 | typedef RegBufferUint8<8> OutputType; |
50 | |
51 | typedef OutputStageSaturatingCastToUint8 OutputStage; |
52 | |
53 | OutputStageEvalBufferImpl(const OutputStage&) {} |
54 | |
55 | OutputType Eval(InputType input) const { |
56 | OutputType output; |
57 | __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]); |
58 | __m128i res_8 = _mm_packus_epi16(res_16, res_16); |
59 | output.reg[0] = _mm_extract_epi32(res_8, 0); |
60 | output.reg[1] = _mm_extract_epi32(res_8, 1); |
61 | return output; |
62 | } |
63 | }; |
64 | |
65 | template <> |
66 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, |
67 | RegBufferInt32<16>> { |
68 | typedef RegBufferInt32<16> InputType; |
69 | typedef RegBufferUint8<16> OutputType; |
70 | |
71 | typedef OutputStageSaturatingCastToUint8 OutputStage; |
72 | |
73 | OutputStageEvalBufferImpl(const OutputStage&) {} |
74 | |
75 | OutputType Eval(InputType input) const { |
76 | OutputType output; |
77 | __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]); |
78 | __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]); |
79 | output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1); |
80 | return output; |
81 | } |
82 | }; |
83 | |
84 | template <> |
85 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, |
86 | RegBufferInt32<32>> { |
87 | typedef RegBufferInt32<32> InputType; |
88 | typedef RegBufferUint8<32> OutputType; |
89 | |
90 | typedef OutputStageSaturatingCastToUint8 OutputStage; |
91 | |
92 | OutputStageEvalBufferImpl(const OutputStage&) {} |
93 | |
94 | OutputType Eval(InputType input) const { |
95 | OutputType output; |
96 | __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]); |
97 | __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]); |
98 | output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1); |
99 | __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]); |
100 | __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]); |
101 | output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3); |
102 | return output; |
103 | } |
104 | }; |
105 | |
106 | template <> |
107 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, |
108 | RegBufferInt32<4>> { |
109 | typedef RegBufferInt32<4> InputType; |
110 | typedef RegBufferInt16<4> OutputType; |
111 | |
112 | typedef OutputStageSaturatingCastToInt16 OutputStage; |
113 | |
114 | OutputStageEvalBufferImpl(const OutputStage&) {} |
115 | |
116 | OutputType Eval(InputType input) const { |
117 | OutputType output; |
118 | __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]); |
119 | output.reg[0] = _mm_extract_epi16(res_16, 0); |
120 | output.reg[1] = _mm_extract_epi16(res_16, 1); |
121 | output.reg[2] = _mm_extract_epi16(res_16, 2); |
122 | output.reg[3] = _mm_extract_epi16(res_16, 3); |
123 | return output; |
124 | } |
125 | }; |
126 | |
127 | template <> |
128 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, |
129 | RegBufferInt32<8>> { |
130 | typedef RegBufferInt32<8> InputType; |
131 | typedef RegBufferInt16<8> OutputType; |
132 | |
133 | typedef OutputStageSaturatingCastToInt16 OutputStage; |
134 | |
135 | OutputStageEvalBufferImpl(const OutputStage&) {} |
136 | |
137 | OutputType Eval(InputType input) const { |
138 | OutputType output; |
139 | output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]); |
140 | return output; |
141 | } |
142 | }; |
143 | |
144 | template <> |
145 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, |
146 | RegBufferInt32<16>> { |
147 | typedef RegBufferInt32<16> InputType; |
148 | typedef RegBufferInt16<16> OutputType; |
149 | |
150 | typedef OutputStageSaturatingCastToInt16 OutputStage; |
151 | |
152 | OutputStageEvalBufferImpl(const OutputStage&) {} |
153 | |
154 | OutputType Eval(InputType input) const { |
155 | OutputType output; |
156 | output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]); |
157 | output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]); |
158 | return output; |
159 | } |
160 | }; |
161 | |
162 | template <> |
163 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, |
164 | RegBufferInt32<32>> { |
165 | typedef RegBufferInt32<32> InputType; |
166 | typedef RegBufferInt16<32> OutputType; |
167 | |
168 | typedef OutputStageSaturatingCastToInt16 OutputStage; |
169 | |
170 | OutputStageEvalBufferImpl(const OutputStage&) {} |
171 | |
172 | OutputType Eval(InputType input) const { |
173 | OutputType output; |
174 | output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]); |
175 | output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]); |
176 | output.reg[2] = _mm_packs_epi32(input.reg[4], input.reg[5]); |
177 | output.reg[3] = _mm_packs_epi32(input.reg[6], input.reg[7]); |
178 | return output; |
179 | } |
180 | }; |
181 | |
182 | template <typename DstType> |
183 | struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> { |
184 | static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row, |
185 | int col) { |
186 | if (DstType::kOrder == MapOrder::ColMajor) { |
187 | StoreInt32x4(dst->data(row, col), src.buf.reg[0]); |
188 | } else { |
189 | *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); |
190 | *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); |
191 | *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); |
192 | *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); |
193 | } |
194 | } |
195 | }; |
196 | |
197 | template <typename DstType> |
198 | struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> { |
199 | static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row, |
200 | int col) { |
201 | if (DstType::kOrder == MapOrder::ColMajor) { |
202 | StoreInt32x4(dst->data(row, col), src.buf.reg[0]); |
203 | StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]); |
204 | } else { |
205 | *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); |
206 | *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); |
207 | *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); |
208 | *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); |
209 | *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]); |
210 | *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]); |
211 | *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]); |
212 | *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]); |
213 | } |
214 | } |
215 | }; |
216 | |
217 | template <typename DstType> |
218 | struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> { |
219 | static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row, |
220 | int col) { |
221 | *dst->data(row + 0, col) = src.buf.reg[0]; |
222 | *dst->data(row + 1, col) = src.buf.reg[1]; |
223 | *dst->data(row + 2, col) = src.buf.reg[2]; |
224 | *dst->data(row + 3, col) = src.buf.reg[3]; |
225 | } |
226 | }; |
227 | |
228 | template <typename DstType> |
229 | struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> { |
230 | static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row, |
231 | int col) { |
232 | if (DstType::kOrder == MapOrder::ColMajor) { |
233 | StoreInt16x8(dst->data(row, col), src.buf.reg[0]); |
234 | } else { |
235 | *dst->data(row + 0, col) = _mm_extract_epi16(src.buf.reg[0], 0); |
236 | *dst->data(row + 1, col) = _mm_extract_epi16(src.buf.reg[0], 1); |
237 | *dst->data(row + 2, col) = _mm_extract_epi16(src.buf.reg[0], 2); |
238 | *dst->data(row + 3, col) = _mm_extract_epi16(src.buf.reg[0], 3); |
239 | *dst->data(row + 4, col) = _mm_extract_epi16(src.buf.reg[0], 4); |
240 | *dst->data(row + 5, col) = _mm_extract_epi16(src.buf.reg[0], 5); |
241 | *dst->data(row + 6, col) = _mm_extract_epi16(src.buf.reg[0], 6); |
242 | *dst->data(row + 7, col) = _mm_extract_epi16(src.buf.reg[0], 7); |
243 | } |
244 | } |
245 | }; |
246 | |
247 | inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) { |
248 | __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]); |
249 | __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]); |
250 | __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]); |
251 | __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]); |
252 | |
253 | RegBlockInt32<4, 4> result; |
254 | result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1); |
255 | result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1); |
256 | result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3); |
257 | result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3); |
258 | return result; |
259 | } |
260 | |
261 | template <typename DstType> |
262 | struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> { |
263 | static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row, |
264 | int col) { |
265 | if (DstType::kOrder == MapOrder::ColMajor) { |
266 | for (int i = 0; i < 4; i++) { |
267 | StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]); |
268 | } |
269 | } else { |
270 | const auto transpose = Transpose(src); |
271 | for (int i = 0; i < 4; i++) { |
272 | StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]); |
273 | } |
274 | } |
275 | } |
276 | }; |
277 | |
278 | template <typename DstType> |
279 | struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> { |
280 | static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row, |
281 | int col) { |
282 | std::int16_t buf[16]; |
283 | StoreInt16x8(buf + 0, src.buf.reg[0]); |
284 | StoreInt16x8(buf + 8, src.buf.reg[1]); |
285 | for (int i = 0; i < 4; i++) { |
286 | for (int j = 0; j < 4; j++) { |
287 | *dst->data(row + i, col + j) = buf[i + 4 * j]; |
288 | } |
289 | } |
290 | } |
291 | }; |
292 | |
293 | template <typename DstType> |
294 | struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> { |
295 | static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row, |
296 | int col) { |
297 | if (DstType::kOrder == MapOrder::ColMajor) { |
298 | for (int i = 0; i < 4; i++) { |
299 | StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]); |
300 | StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]); |
301 | } |
302 | } else { |
303 | RegBlockInt32<4, 4> top; |
304 | top.buf.reg[0] = src.buf.reg[0]; |
305 | top.buf.reg[1] = src.buf.reg[2]; |
306 | top.buf.reg[2] = src.buf.reg[4]; |
307 | top.buf.reg[3] = src.buf.reg[6]; |
308 | const auto transpose_top = Transpose(top); |
309 | for (int i = 0; i < 4; i++) { |
310 | StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]); |
311 | } |
312 | RegBlockInt32<4, 4> bottom; |
313 | bottom.buf.reg[0] = src.buf.reg[1]; |
314 | bottom.buf.reg[1] = src.buf.reg[3]; |
315 | bottom.buf.reg[2] = src.buf.reg[5]; |
316 | bottom.buf.reg[3] = src.buf.reg[7]; |
317 | const auto transpose_bottom = Transpose(bottom); |
318 | for (int i = 0; i < 4; i++) { |
319 | StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]); |
320 | } |
321 | } |
322 | } |
323 | }; |
324 | |
325 | template <typename DstType> |
326 | struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> { |
327 | static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row, |
328 | int col) { |
329 | if (DstType::kOrder == MapOrder::ColMajor) { |
330 | for (int i = 0; i < 4; i++) { |
331 | StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]); |
332 | } |
333 | } else { |
334 | std::int16_t buf[32]; |
335 | StoreInt16x8(buf + 0, src.buf.reg[0]); |
336 | StoreInt16x8(buf + 8, src.buf.reg[1]); |
337 | StoreInt16x8(buf + 16, src.buf.reg[2]); |
338 | StoreInt16x8(buf + 24, src.buf.reg[3]); |
339 | for (int i = 0; i < 8; i++) { |
340 | for (int j = 0; j < 4; j++) { |
341 | *dst->data(row + i, col + j) = buf[i + 8 * j]; |
342 | } |
343 | } |
344 | } |
345 | } |
346 | }; |
347 | |
348 | template <typename DstType> |
349 | struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> { |
350 | static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row, |
351 | int col) { |
352 | if (DstType::kOrder == MapOrder::ColMajor) { |
353 | for (int i = 0; i < 8; i++) { |
354 | StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]); |
355 | StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]); |
356 | } |
357 | } else { |
358 | RegBlockInt32<4, 4> top_left; |
359 | top_left.buf.reg[0] = src.buf.reg[0]; |
360 | top_left.buf.reg[1] = src.buf.reg[2]; |
361 | top_left.buf.reg[2] = src.buf.reg[4]; |
362 | top_left.buf.reg[3] = src.buf.reg[6]; |
363 | const auto transpose_top_left = Transpose(top_left); |
364 | for (int i = 0; i < 4; i++) { |
365 | StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]); |
366 | } |
367 | RegBlockInt32<4, 4> bottom_left; |
368 | bottom_left.buf.reg[0] = src.buf.reg[1]; |
369 | bottom_left.buf.reg[1] = src.buf.reg[3]; |
370 | bottom_left.buf.reg[2] = src.buf.reg[5]; |
371 | bottom_left.buf.reg[3] = src.buf.reg[7]; |
372 | const auto transpose_bottom_left = Transpose(bottom_left); |
373 | for (int i = 0; i < 4; i++) { |
374 | StoreInt32x4(dst->data(row + 4 + i, col), |
375 | transpose_bottom_left.buf.reg[i]); |
376 | } |
377 | RegBlockInt32<4, 4> top_right; |
378 | top_right.buf.reg[0] = src.buf.reg[8]; |
379 | top_right.buf.reg[1] = src.buf.reg[10]; |
380 | top_right.buf.reg[2] = src.buf.reg[12]; |
381 | top_right.buf.reg[3] = src.buf.reg[14]; |
382 | const auto transpose_top_right = Transpose(top_right); |
383 | for (int i = 0; i < 4; i++) { |
384 | StoreInt32x4(dst->data(row + i, col + 4), |
385 | transpose_top_right.buf.reg[i]); |
386 | } |
387 | RegBlockInt32<4, 4> bottom_right; |
388 | bottom_right.buf.reg[0] = src.buf.reg[9]; |
389 | bottom_right.buf.reg[1] = src.buf.reg[11]; |
390 | bottom_right.buf.reg[2] = src.buf.reg[13]; |
391 | bottom_right.buf.reg[3] = src.buf.reg[15]; |
392 | const auto transpose_bottom_right = Transpose(bottom_right); |
393 | for (int i = 0; i < 4; i++) { |
394 | StoreInt32x4(dst->data(row + 4 + i, col + 4), |
395 | transpose_bottom_right.buf.reg[i]); |
396 | } |
397 | } |
398 | } |
399 | }; |
400 | |
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
  // Writes an 8x8 int16 block. Each src.buf.reg[i] holds column i
  // (rows 0..7). The row-major path does an in-register 8x8 transpose
  // in three rounds of unpack instructions, then stores one destination
  // row per vector.
  static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Each register is one contiguous destination column.
      for (int i = 0; i < 8; i++) {
        StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      // top-left 4x4: rows 0..3 of columns 0..3.
      // u0 = rows 0,1 interleaved across cols 0..3; u1 = rows 2,3.
      __m128i t0 = _mm_unpacklo_epi16(src.buf.reg[0], src.buf.reg[1]);
      __m128i t1 = _mm_unpacklo_epi16(src.buf.reg[2], src.buf.reg[3]);
      __m128i u0 = _mm_unpacklo_epi32(t0, t1);
      __m128i u1 = _mm_unpackhi_epi32(t0, t1);
      // top-right 4x4: rows 0..3 of columns 4..7.
      __m128i t2 = _mm_unpacklo_epi16(src.buf.reg[4], src.buf.reg[5]);
      __m128i t3 = _mm_unpacklo_epi16(src.buf.reg[6], src.buf.reg[7]);
      __m128i u2 = _mm_unpacklo_epi32(t2, t3);
      __m128i u3 = _mm_unpackhi_epi32(t2, t3);
      // bottom-left 4x4: rows 4..7 of columns 0..3.
      __m128i t4 = _mm_unpackhi_epi16(src.buf.reg[0], src.buf.reg[1]);
      __m128i t5 = _mm_unpackhi_epi16(src.buf.reg[2], src.buf.reg[3]);
      __m128i u4 = _mm_unpacklo_epi32(t4, t5);
      __m128i u5 = _mm_unpackhi_epi32(t4, t5);
      // bottom-right 4x4: rows 4..7 of columns 4..7.
      __m128i t6 = _mm_unpackhi_epi16(src.buf.reg[4], src.buf.reg[5]);
      __m128i t7 = _mm_unpackhi_epi16(src.buf.reg[6], src.buf.reg[7]);
      __m128i u6 = _mm_unpacklo_epi32(t6, t7);
      __m128i u7 = _mm_unpackhi_epi32(t6, t7);

      // Final 64-bit combines: glue the left and right halves of each
      // transposed row, then store each full destination row at once.
      StoreInt16x8(dst->data(row + 0, col), _mm_unpacklo_epi64(u0, u2));
      StoreInt16x8(dst->data(row + 1, col), _mm_unpackhi_epi64(u0, u2));
      StoreInt16x8(dst->data(row + 2, col), _mm_unpacklo_epi64(u1, u3));
      StoreInt16x8(dst->data(row + 3, col), _mm_unpackhi_epi64(u1, u3));
      StoreInt16x8(dst->data(row + 4, col), _mm_unpacklo_epi64(u4, u6));
      StoreInt16x8(dst->data(row + 5, col), _mm_unpackhi_epi64(u4, u6));
      StoreInt16x8(dst->data(row + 6, col), _mm_unpacklo_epi64(u5, u7));
      StoreInt16x8(dst->data(row + 7, col), _mm_unpackhi_epi64(u5, u7));
    }
  }
};
442 | |
443 | template <typename DstType> |
444 | struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> { |
445 | static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row, |
446 | int col) { |
447 | if (DstType::kOrder == MapOrder::ColMajor) { |
448 | *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]); |
449 | *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]); |
450 | *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]); |
451 | *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]); |
452 | } else { |
453 | StoreInt32x4(dst->data(row, col), src.buf.reg[0]); |
454 | } |
455 | } |
456 | }; |
457 | |
458 | template <typename DstType> |
459 | struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> { |
460 | static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row, |
461 | int col) { |
462 | const std::uint32_t src_reg = src.buf.reg[0]; |
463 | for (int i = 0; i < 4; i++) { |
464 | *dst->data(row + i, col) = (src_reg >> (8 * i)); |
465 | } |
466 | } |
467 | }; |
468 | |
469 | template <typename DstType> |
470 | struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> { |
471 | static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row, |
472 | int col) { |
473 | for (int i = 0; i < 4; i++) { |
474 | *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i)); |
475 | } |
476 | for (int i = 0; i < 4; i++) { |
477 | *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i)); |
478 | } |
479 | } |
480 | }; |
481 | |
482 | template <typename DstType> |
483 | struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> { |
484 | static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row, |
485 | int col) { |
486 | for (int i = 0; i < 4; i++) { |
487 | *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i)); |
488 | } |
489 | } |
490 | }; |
491 | |
492 | template <typename DstType> |
493 | struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> { |
494 | static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row, |
495 | int col) { |
496 | std::uint8_t buf[16]; |
497 | StoreUint8x16(buf, src.buf.reg[0]); |
498 | for (int c = 0; c < 4; c++) { |
499 | for (int r = 0; r < 4; r++) { |
500 | *dst->data(row + r, col + c) = buf[r + 4 * c]; |
501 | } |
502 | } |
503 | } |
504 | }; |
505 | |
506 | template <typename DstType> |
507 | struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> { |
508 | static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row, |
509 | int col) { |
510 | std::uint8_t buf[32]; |
511 | StoreUint8x16(buf, src.buf.reg[0]); |
512 | StoreUint8x16(buf + 16, src.buf.reg[1]); |
513 | for (int c = 0; c < 4; c++) { |
514 | for (int r = 0; r < 8; r++) { |
515 | *dst->data(row + r, col + c) = buf[r + 8 * c]; |
516 | } |
517 | } |
518 | } |
519 | }; |
520 | |
521 | template <typename DstType> |
522 | struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> { |
523 | static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row, |
524 | int col) { |
525 | std::uint8_t buf[64]; |
526 | StoreUint8x16(buf, src.buf.reg[0]); |
527 | StoreUint8x16(buf + 16, src.buf.reg[1]); |
528 | StoreUint8x16(buf + 32, src.buf.reg[2]); |
529 | StoreUint8x16(buf + 48, src.buf.reg[3]); |
530 | for (int c = 0; c < 8; c++) { |
531 | for (int r = 0; r < 8; r++) { |
532 | *dst->data(row + r, col + c) = buf[r + 8 * c]; |
533 | } |
534 | } |
535 | } |
536 | }; |
537 | |
538 | // Specialization for MatrixMap, for performance. |
539 | template <typename tScalar, MapOrder tOrder> |
540 | struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, MatrixMap<tScalar, tOrder>> { |
541 | static void Run(const RegBlockUint8<8, 8>& src, |
542 | MatrixMap<tScalar, tOrder>* dst, int row, int col) { |
543 | std::uint8_t buf[64]; |
544 | StoreUint8x16(buf, src.buf.reg[0]); |
545 | StoreUint8x16(buf + 16, src.buf.reg[1]); |
546 | StoreUint8x16(buf + 32, src.buf.reg[2]); |
547 | StoreUint8x16(buf + 48, src.buf.reg[3]); |
548 | // Make a local copy so that the compiler can prove that data_ does not |
549 | // alias &data_ or &stride_. |
550 | MatrixMap<tScalar, tOrder> local = *dst; |
551 | for (int c = 0; c < 8; c++) { |
552 | for (int r = 0; r < 8; r++) { |
553 | *local.data(row + r, col + c) = buf[r + 8 * c]; |
554 | } |
555 | } |
556 | } |
557 | }; |
558 | |
559 | } // namespace gemmlowp |
560 | |
561 | #endif // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ |
562 | |