1 | // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // pack_SSE.h: optimized SSE specializations of the templates in pack.h. |
16 | |
17 | #ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_ |
18 | #define GEMMLOWP_INTERNAL_PACK_SSE_H_ |
19 | |
20 | #include <smmintrin.h> |
21 | #include "pack.h" |
22 | |
23 | namespace gemmlowp { |
24 | |
25 | // TODO: Add DepthMajorUint8SideMap |
26 | |
// Source-side map type handled by this file: uint8 entries traversed in
// WidthMajor order.
typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
    WidthMajorUint8SideMap;

// Kernel side format made of `Cells` cells, each 4 (width) x 2 (depth),
// stored WidthMajor.
template <int Cells>
using WidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
33 | |
34 | template <int Cells> |
35 | class PackingRegisterBlock< |
36 | WidthMajorUint8SideMap, |
37 | PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > |
38 | : public PackingRegisterBlockBase< |
39 | WidthMajorUint8SideMap, |
40 | PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > { |
41 | public: |
42 | typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; |
43 | typedef typename KernelSideFormat::Cell CellFormat; |
44 | static constexpr int kCells = KernelSideFormat::kCells; |
45 | static constexpr int kCellWidth = CellFormat::kWidth; |
46 | static constexpr int kKernelWidth = CellFormat::kWidth * kCells; |
47 | static constexpr int kCellDepth = CellFormat::kDepth; |
48 | static constexpr int kCellSize = CellFormat::kSize; |
49 | |
50 | void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { |
51 | std::uint8_t* dst_ptr = dst->current_data(); |
52 | const int width_stride = this->complete_src_.width_stride(); |
53 | int depth_step = 8; |
54 | |
55 | __m128i one = _mm_set1_epi16(1); |
56 | for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; |
57 | cell_start_depth += depth_step) { |
58 | for (int cell_start_width = 0; cell_start_width < kKernelWidth; |
59 | cell_start_width += kCellWidth) { |
60 | std::int32_t* cell_sums_of_each_slice_ptr = |
61 | dst->sums_of_each_slice() + start_width + cell_start_width; |
62 | const std::uint8_t* src_data = |
63 | this->complete_src_.data(cell_start_width, cell_start_depth); |
64 | |
65 | __m128i xmm1 = |
66 | _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0])); |
67 | __m128i xmm2 = _mm_loadl_epi64( |
68 | reinterpret_cast<const __m128i*>(&src_data[1 * width_stride])); |
69 | __m128i xmm3 = _mm_loadl_epi64( |
70 | reinterpret_cast<const __m128i*>(&src_data[2 * width_stride])); |
71 | __m128i xmm4 = _mm_loadl_epi64( |
72 | reinterpret_cast<const __m128i*>(&src_data[3 * width_stride])); |
73 | |
74 | __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2); |
75 | __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31); |
76 | |
77 | __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4); |
78 | __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80); |
79 | |
80 | __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc); |
81 | __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc); |
82 | |
83 | _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9); |
84 | _mm_storel_epi64( |
85 | reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10); |
86 | |
87 | __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee); |
88 | __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee); |
89 | |
90 | _mm_storel_epi64( |
91 | reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]), |
92 | xmm11); |
93 | _mm_storel_epi64( |
94 | reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]), |
95 | xmm12); |
96 | |
97 | xmm1 = _mm_cvtepu8_epi16(xmm9); |
98 | xmm2 = _mm_madd_epi16(xmm1, one); |
99 | __m128i sums_of_each_slice_xmm = _mm_loadu_si128( |
100 | reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0])); |
101 | sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); |
102 | |
103 | xmm1 = _mm_cvtepu8_epi16(xmm10); |
104 | xmm2 = _mm_madd_epi16(xmm1, one); |
105 | sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); |
106 | |
107 | xmm1 = _mm_cvtepu8_epi16(xmm11); |
108 | xmm2 = _mm_madd_epi16(xmm1, one); |
109 | sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); |
110 | |
111 | xmm1 = _mm_cvtepu8_epi16(xmm12); |
112 | xmm2 = _mm_madd_epi16(xmm1, one); |
113 | sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); |
114 | |
115 | _mm_storeu_si128( |
116 | reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]), |
117 | sums_of_each_slice_xmm); |
118 | dst_ptr += kCellSize; |
119 | } |
120 | dst_ptr += 3 * kCellSize * kCells; |
121 | } |
122 | dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); |
123 | } |
124 | }; |
125 | |
126 | } // namespace gemmlowp |
127 | |
128 | #endif // GEMMLOWP_INTERNAL_PACK_SSE_H_ |
129 | |