1// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
// pack_SSE.h: optimized SSE4.1 specializations of the templates in pack.h.
16
17#ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_
18#define GEMMLOWP_INTERNAL_PACK_SSE_H_
19
20#include <smmintrin.h>
21#include "pack.h"
22
23namespace gemmlowp {
24
25// TODO: Add DepthMajorUint8SideMap
26
27typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
28 WidthMajorUint8SideMap;
29
30template <int Cells>
31using WidthMajorSideFormatNCells4x2 =
32 KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
33
34template <int Cells>
35class PackingRegisterBlock<
36 WidthMajorUint8SideMap,
37 PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > >
38 : public PackingRegisterBlockBase<
39 WidthMajorUint8SideMap,
40 PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > {
41 public:
42 typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
43 typedef typename KernelSideFormat::Cell CellFormat;
44 static constexpr int kCells = KernelSideFormat::kCells;
45 static constexpr int kCellWidth = CellFormat::kWidth;
46 static constexpr int kKernelWidth = CellFormat::kWidth * kCells;
47 static constexpr int kCellDepth = CellFormat::kDepth;
48 static constexpr int kCellSize = CellFormat::kSize;
49
50 void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
51 std::uint8_t* dst_ptr = dst->current_data();
52 const int width_stride = this->complete_src_.width_stride();
53 int depth_step = 8;
54
55 __m128i one = _mm_set1_epi16(1);
56 for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
57 cell_start_depth += depth_step) {
58 for (int cell_start_width = 0; cell_start_width < kKernelWidth;
59 cell_start_width += kCellWidth) {
60 std::int32_t* cell_sums_of_each_slice_ptr =
61 dst->sums_of_each_slice() + start_width + cell_start_width;
62 const std::uint8_t* src_data =
63 this->complete_src_.data(cell_start_width, cell_start_depth);
64
65 __m128i xmm1 =
66 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0]));
67 __m128i xmm2 = _mm_loadl_epi64(
68 reinterpret_cast<const __m128i*>(&src_data[1 * width_stride]));
69 __m128i xmm3 = _mm_loadl_epi64(
70 reinterpret_cast<const __m128i*>(&src_data[2 * width_stride]));
71 __m128i xmm4 = _mm_loadl_epi64(
72 reinterpret_cast<const __m128i*>(&src_data[3 * width_stride]));
73
74 __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
75 __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
76
77 __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
78 __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
79
80 __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
81 __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
82
83 _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9);
84 _mm_storel_epi64(
85 reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10);
86
87 __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
88 __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
89
90 _mm_storel_epi64(
91 reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]),
92 xmm11);
93 _mm_storel_epi64(
94 reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]),
95 xmm12);
96
97 xmm1 = _mm_cvtepu8_epi16(xmm9);
98 xmm2 = _mm_madd_epi16(xmm1, one);
99 __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
100 reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0]));
101 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
102
103 xmm1 = _mm_cvtepu8_epi16(xmm10);
104 xmm2 = _mm_madd_epi16(xmm1, one);
105 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
106
107 xmm1 = _mm_cvtepu8_epi16(xmm11);
108 xmm2 = _mm_madd_epi16(xmm1, one);
109 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
110
111 xmm1 = _mm_cvtepu8_epi16(xmm12);
112 xmm2 = _mm_madd_epi16(xmm1, one);
113 sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
114
115 _mm_storeu_si128(
116 reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]),
117 sums_of_each_slice_xmm);
118 dst_ptr += kCellSize;
119 }
120 dst_ptr += 3 * kCellSize * kCells;
121 }
122 dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
123 }
124};
125
126} // namespace gemmlowp
127
128#endif // GEMMLOWP_INTERNAL_PACK_SSE_H_
129