1 | // Copyright 2017 The Gemmlowp Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // dispatch_gemm_shape.h: dispatch GEMM calls according to their shape |
16 | |
17 | #ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ |
18 | #define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ |
19 | |
#include <cstddef>
#include <tuple>

#include "../internal/kernel_default.h"
#include "../public/map.h"
#include "../public/output_stages.h"
#include "multi_thread_gemm.h"
24 | |
25 | namespace gemmlowp { |
26 | |
// Generic fallback of the Transpose machinery: a type that has no dedicated
// TransposeImpl specialization is treated as its own transpose, so Run()
// simply hands back a copy of its argument.
template <typename T>
struct TransposeImpl {
  // Type produced by transposing a T; in this generic case it is T itself.
  using DstType = T;

  // Returns `t` unchanged (by value).
  static DstType Run(const T& t) {
    return t;
  }
};
32 | |
33 | template <typename T> |
34 | using TransposeType = typename TransposeImpl<T>::DstType; |
35 | |
36 | template <typename T> |
37 | TransposeType<T> Transpose(const T& t) { |
38 | return TransposeImpl<T>::Run(t); |
39 | } |
40 | |
41 | template <MapOrder Order> |
42 | struct TransposeMapOrder { |
43 | static constexpr MapOrder Value = |
44 | Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor; |
45 | }; |
46 | |
47 | template <VectorShape Shape> |
48 | struct TransposeVectorShape { |
49 | static constexpr VectorShape Value = |
50 | Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row; |
51 | }; |
52 | |
53 | template <typename Scalar, VectorShape Shape> |
54 | struct TransposeImpl<VectorMap<Scalar, Shape>> { |
55 | typedef VectorMap<Scalar, Shape> SrcType; |
56 | static constexpr VectorShape TransposedShape = |
57 | TransposeVectorShape<Shape>::Value; |
58 | typedef VectorMap<Scalar, TransposedShape> DstType; |
59 | static DstType Run(const SrcType& src) { |
60 | return DstType(src.data(), src.size()); |
61 | } |
62 | }; |
63 | |
64 | template <typename Scalar, MapOrder Order> |
65 | struct TransposeImpl<MatrixMap<Scalar, Order>> { |
66 | typedef MatrixMap<Scalar, Order> SrcType; |
67 | static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value; |
68 | typedef MatrixMap<Scalar, TransposedOrder> DstType; |
69 | static DstType Run(const SrcType& src) { |
70 | return DstType(src.data(), src.cols(), src.rows(), src.stride()); |
71 | } |
72 | }; |
73 | |
74 | template <VectorShape Shape> |
75 | struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> { |
76 | typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType; |
77 | static constexpr VectorShape TransposedShape = |
78 | TransposeVectorShape<Shape>::Value; |
79 | typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType; |
80 | static DstType Run(const SrcType& src) { |
81 | DstType dst; |
82 | dst.result_shift = src.result_shift; |
83 | dst.result_offset = Transpose(src.result_offset); |
84 | dst.result_mult_int = Transpose(src.result_mult_int); |
85 | return dst; |
86 | } |
87 | }; |
88 | |
89 | template <VectorShape Shape> |
90 | struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> { |
91 | typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType; |
92 | static constexpr VectorShape TransposedShape = |
93 | TransposeVectorShape<Shape>::Value; |
94 | typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape> |
95 | DstType; |
96 | static DstType Run(const SrcType& src) { |
97 | DstType dst; |
98 | dst.result_fixedpoint_multiplier = |
99 | Transpose(src.result_fixedpoint_multiplier); |
100 | dst.result_exponent = Transpose(src.result_exponent); |
101 | dst.result_offset_after_shift = src.result_offset_after_shift; |
102 | return dst; |
103 | } |
104 | }; |
105 | |
106 | template <typename VectorMapType> |
107 | struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> { |
108 | typedef OutputStageBiasAddition<VectorMapType> SrcType; |
109 | typedef TransposeType<VectorMapType> TransposedVectorMapType; |
110 | typedef OutputStageBiasAddition<TransposedVectorMapType> DstType; |
111 | static DstType Run(const SrcType& src) { |
112 | DstType dst; |
113 | dst.bias_vector = Transpose(src.bias_vector); |
114 | return dst; |
115 | } |
116 | }; |
117 | |
118 | // TODO(benoitjacob) - does anyone understand C++ variadic templates? |
119 | // How to use them to implement TransposeTuple? Note: there are lots |
120 | // of answers on StackOverflow but they seem to all involve either |
121 | // C++14/C++17 (we can only use C++11) or lots of abstract nonsense. |
122 | inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; } |
123 | |
124 | template <typename T0> |
125 | std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) { |
126 | return std::make_tuple(Transpose(std::get<0>(t))); |
127 | } |
128 | |
129 | template <typename T0, typename T1> |
130 | std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple( |
131 | const std::tuple<T0, T1>& t) { |
132 | return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t))); |
133 | } |
134 | |
135 | template <typename T0, typename T1, typename T2> |
136 | std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>> |
137 | TransposeTuple(const std::tuple<T0, T1, T2>& t) { |
138 | return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), |
139 | Transpose(std::get<2>(t))); |
140 | } |
141 | |
142 | template <typename T0, typename T1, typename T2, typename T3> |
143 | std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, |
144 | TransposeType<T3>> |
145 | TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) { |
146 | return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), |
147 | Transpose(std::get<2>(t)), Transpose(std::get<3>(t))); |
148 | } |
149 | |
150 | template <typename T0, typename T1, typename T2, typename T3, typename T4> |
151 | std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, |
152 | TransposeType<T3>, TransposeType<T4>> |
153 | TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) { |
154 | return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), |
155 | Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), |
156 | Transpose(std::get<4>(t))); |
157 | } |
158 | |
159 | template <typename T0, typename T1, typename T2, typename T3, typename T4, |
160 | typename T5> |
161 | std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, |
162 | TransposeType<T3>, TransposeType<T4>, TransposeType<T5>> |
163 | TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) { |
164 | return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), |
165 | Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), |
166 | Transpose(std::get<4>(t)), Transpose(std::get<5>(t))); |
167 | } |
168 | |
169 | template <typename InputScalar, typename OutputScalar, typename BitDepthParams, |
170 | MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder, |
171 | typename LhsOffset, typename RhsOffset, typename OutputPipelineType, |
172 | typename GemmContextType> |
173 | void DispatchGemmShape(GemmContextType* context, |
174 | const MatrixMap<const InputScalar, LhsOrder>& lhs, |
175 | const MatrixMap<const InputScalar, RhsOrder>& rhs, |
176 | MatrixMap<OutputScalar, ResultOrder>* result, |
177 | const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, |
178 | const OutputPipelineType& output_pipeline) { |
179 | assert(lhs.cols() == rhs.rows()); |
180 | |
181 | int rows = result->rows(); |
182 | int cols = result->cols(); |
183 | int depth = lhs.cols(); |
184 | |
185 | if (rows == 0 || cols == 0 || depth == 0) { |
186 | // Vacuous GEMM, return early to avoid having to deal with |
187 | // zero sizes below. |
188 | return; |
189 | } |
190 | |
191 | if (rows < cols) { |
192 | auto transposed_result_map = Transpose(*result); |
193 | return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>( |
194 | context, Transpose(rhs), Transpose(lhs), &transposed_result_map, |
195 | Transpose(rhs_offset), Transpose(lhs_offset), |
196 | TransposeTuple(output_pipeline)); |
197 | } |
198 | |
199 | typedef DefaultKernel<BitDepthParams> Kernel; |
200 | MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar, |
201 | BitDepthParams>(context, Kernel(), lhs, rhs, result, |
202 | lhs_offset, rhs_offset, output_pipeline); |
203 | } |
204 | |
205 | } // end namespace gemmlowp |
206 | |
207 | #endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ |
208 | |