1// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
16
17#ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
18#define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
19
#include <cstddef>
#include <tuple>

#include "../internal/kernel_default.h"
#include "../public/map.h"
#include "../public/output_stages.h"
#include "multi_thread_gemm.h"
24
25namespace gemmlowp {
26
// Primary TransposeImpl template: the fallback used for any type that has no
// dedicated specialization. The transposed type is the type itself, and Run()
// hands back a copy of its argument.
template <typename T>
struct TransposeImpl {
  using DstType = T;
  static DstType Run(const T& value) { return value; }
};
32
33template <typename T>
34using TransposeType = typename TransposeImpl<T>::DstType;
35
36template <typename T>
37TransposeType<T> Transpose(const T& t) {
38 return TransposeImpl<T>::Run(t);
39}
40
41template <MapOrder Order>
42struct TransposeMapOrder {
43 static constexpr MapOrder Value =
44 Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
45};
46
47template <VectorShape Shape>
48struct TransposeVectorShape {
49 static constexpr VectorShape Value =
50 Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
51};
52
53template <typename Scalar, VectorShape Shape>
54struct TransposeImpl<VectorMap<Scalar, Shape>> {
55 typedef VectorMap<Scalar, Shape> SrcType;
56 static constexpr VectorShape TransposedShape =
57 TransposeVectorShape<Shape>::Value;
58 typedef VectorMap<Scalar, TransposedShape> DstType;
59 static DstType Run(const SrcType& src) {
60 return DstType(src.data(), src.size());
61 }
62};
63
64template <typename Scalar, MapOrder Order>
65struct TransposeImpl<MatrixMap<Scalar, Order>> {
66 typedef MatrixMap<Scalar, Order> SrcType;
67 static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
68 typedef MatrixMap<Scalar, TransposedOrder> DstType;
69 static DstType Run(const SrcType& src) {
70 return DstType(src.data(), src.cols(), src.rows(), src.stride());
71 }
72};
73
74template <VectorShape Shape>
75struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
76 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
77 static constexpr VectorShape TransposedShape =
78 TransposeVectorShape<Shape>::Value;
79 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
80 static DstType Run(const SrcType& src) {
81 DstType dst;
82 dst.result_shift = src.result_shift;
83 dst.result_offset = Transpose(src.result_offset);
84 dst.result_mult_int = Transpose(src.result_mult_int);
85 return dst;
86 }
87};
88
89template <VectorShape Shape>
90struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
91 typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
92 static constexpr VectorShape TransposedShape =
93 TransposeVectorShape<Shape>::Value;
94 typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
95 DstType;
96 static DstType Run(const SrcType& src) {
97 DstType dst;
98 dst.result_fixedpoint_multiplier =
99 Transpose(src.result_fixedpoint_multiplier);
100 dst.result_exponent = Transpose(src.result_exponent);
101 dst.result_offset_after_shift = src.result_offset_after_shift;
102 return dst;
103 }
104};
105
106template <typename VectorMapType>
107struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
108 typedef OutputStageBiasAddition<VectorMapType> SrcType;
109 typedef TransposeType<VectorMapType> TransposedVectorMapType;
110 typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
111 static DstType Run(const SrcType& src) {
112 DstType dst;
113 dst.bias_vector = Transpose(src.bias_vector);
114 return dst;
115 }
116};
117
118// TODO(benoitjacob) - does anyone understand C++ variadic templates?
119// How to use them to implement TransposeTuple? Note: there are lots
120// of answers on StackOverflow but they seem to all involve either
121// C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
122inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
123
124template <typename T0>
125std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
126 return std::make_tuple(Transpose(std::get<0>(t)));
127}
128
129template <typename T0, typename T1>
130std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
131 const std::tuple<T0, T1>& t) {
132 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
133}
134
135template <typename T0, typename T1, typename T2>
136std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
137TransposeTuple(const std::tuple<T0, T1, T2>& t) {
138 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
139 Transpose(std::get<2>(t)));
140}
141
142template <typename T0, typename T1, typename T2, typename T3>
143std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
144 TransposeType<T3>>
145TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
146 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
147 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
148}
149
150template <typename T0, typename T1, typename T2, typename T3, typename T4>
151std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
152 TransposeType<T3>, TransposeType<T4>>
153TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
154 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
155 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
156 Transpose(std::get<4>(t)));
157}
158
159template <typename T0, typename T1, typename T2, typename T3, typename T4,
160 typename T5>
161std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
162 TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
163TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
164 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
165 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
166 Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
167}
168
169template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
170 MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
171 typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
172 typename GemmContextType>
173void DispatchGemmShape(GemmContextType* context,
174 const MatrixMap<const InputScalar, LhsOrder>& lhs,
175 const MatrixMap<const InputScalar, RhsOrder>& rhs,
176 MatrixMap<OutputScalar, ResultOrder>* result,
177 const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
178 const OutputPipelineType& output_pipeline) {
179 assert(lhs.cols() == rhs.rows());
180
181 int rows = result->rows();
182 int cols = result->cols();
183 int depth = lhs.cols();
184
185 if (rows == 0 || cols == 0 || depth == 0) {
186 // Vacuous GEMM, return early to avoid having to deal with
187 // zero sizes below.
188 return;
189 }
190
191 if (rows < cols) {
192 auto transposed_result_map = Transpose(*result);
193 return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
194 context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
195 Transpose(rhs_offset), Transpose(lhs_offset),
196 TransposeTuple(output_pipeline));
197 }
198
199 typedef DefaultKernel<BitDepthParams> Kernel;
200 MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
201 BitDepthParams>(context, Kernel(), lhs, rhs, result,
202 lhs_offset, rhs_offset, output_pipeline);
203}
204
205} // end namespace gemmlowp
206
207#endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
208