1// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef GEMMLOWP_META_BASE_H_
16#define GEMMLOWP_META_BASE_H_
17
18#include <cassert>
19#include <cstdint>
20
21#include "../internal/common.h"
22
23namespace gemmlowp {
24namespace meta {
25
// Rounds `value` up to the nearest multiple of the compile-time constant
// `align`, i.e. returns ceil(value / align) * align.
//
// A remainder-based form is used instead of the classic
// `((value + align - 1) / align) * align` because the classic form:
//  * can overflow in `value + align - 1` (undefined behavior) even when the
//    aligned result fits in an int, e.g. AlignTo<10>(2147483640);
//  * truncates toward zero for negative inputs, returning e.g. 0 for -4 with
//    align 4. Here negative values also round toward +infinity
//    (-5 -> -4, -4 -> -4).
// The result is still undefined if the aligned value itself does not fit in
// an int.
template <int align>
inline int AlignTo(int value) {
  static_assert(align > 0, "AlignTo: align must be positive.");
  // Since C++11 `%` takes the sign of the dividend, so `remainder` lies in
  // [0, align) for non-negative values and in (-align, 0] for negative ones.
  const int remainder = value % align;
  return remainder > 0 ? value + (align - remainder) : value - remainder;
}
30
// Runtime-`align` overload of AlignTo: rounds `value` up to the nearest
// multiple of `align`, i.e. returns ceil(value / align) * align.
//
// `align` must be positive; zero would divide by zero and a negative value
// makes "round up to a multiple" ill-defined. The remainder-based form avoids
// the intermediate overflow of `value + align - 1` when the aligned result
// still fits in an int. It also rounds negative values toward +infinity
// (-5 -> -4 for align 4) instead of truncating toward zero.
inline int AlignTo(int align, int value) {
  assert(align > 0);
  // Since C++11 `%` takes the sign of the dividend, so `remainder` lies in
  // [0, align) for non-negative values and in (-align, 0] for negative ones.
  const int remainder = value % align;
  return remainder > 0 ? value + (align - remainder) : value - remainder;
}
34
// Pairs a `Kernel` value with an `OutputStream` value. MulKernel::Multiply in
// this header receives one of these as its `params` argument. The member
// aliases re-export the template arguments so generic code can recover them.
template <typename Kernel_, typename OutputStream_>
struct FusedKernelParams {
  using Kernel = Kernel_;
  using OutputStream = OutputStream_;

  Kernel kernel;               // value of the kernel parameter type
  OutputStream output_stream;  // value of the output-stream parameter type
};
44
45template <typename InType_, typename OutType_, typename LeftStream_,
46 typename RightStream_, typename Kernel_, typename OutputStream_>
47struct GemmParams {
48 public:
49 typedef InType_ InType;
50 typedef OutType_ OutType;
51 typedef LeftStream_ LeftStream;
52 typedef RightStream_ RightStream;
53 typedef Kernel_ Kernel;
54 typedef OutputStream_ OutputStream;
55
56 typedef FusedKernelParams<Kernel, OutputStream> FusedKernel;
57
58 // Common parameters.
59
60 int m;
61 int n;
62 int k;
63
64 const InType* lhs;
65 const InType* rhs;
66 OutType* result;
67 std::uint8_t* scratch;
68
69 // Specialized parameters.
70
71 LeftStream left_stream;
72 RightStream right_stream;
73 FusedKernel fused_kernel;
74};
75
// Generic declaration of the stream (packing) interface. Only the member
// declarations live here; no definitions are visible in this header, so every
// instantiation that is used must be provided elsewhere (presumably by
// specializations -- TODO confirm).
//
// NOTE(review): the meaning of `lanes_count`, `pack_size` and `leftovers` is
// not established here. The names suggest lanes processed per step, elements
// per packed chunk and tail length -- confirm.
template <typename InType, int lanes_count, int pack_size, int leftovers,
          typename StreamParams>
class Stream {
 public:
  // Reads from `in` (const) and writes to `out`; presumably produces the
  // packed layout consumed by the kernels.
  static void Pack(const InType* in, const StreamParams& params, InType* out);

  // Step sizes within the unpacked (source) data.
  static int UnpackedAdvance(const StreamParams& params);
  static int UnpackedStride(const StreamParams& params);

  // Step sizes within the packed (destination) data.
  static int PackedAdvance(const StreamParams& params);
  static int PackedStride(const StreamParams& params);
};
90
// Helpers tied to a particular stream parameter type `StreamType`.
// Declaration only -- no definitions are visible in this header.
template <typename InType, typename StreamType>
class StreamUtil {
 public:
  // Scratch requirement for `lanes` lanes of this stream.
  // NOTE(review): the unit (bytes vs. elements) is not established here.
  static int Scratch(const StreamType& params, int lanes);

  // Returns `source` moved by `offset_stride` strides and `offset_advance`
  // advances. NOTE(review): exact arithmetic lives in the definitions.
  static const InType* Offset(const StreamType& params, const InType* source,
                              int offset_stride, int offset_advance);
};
99
100template <typename InType, typename OutType, typename Kernel,
101 typename OutputStream, int kernel_m, int kernel_n, int pack_size>
102class MulKernel {
103 public:
104 static void Multiply(const InType* lhs, const InType* rhs,
105 const FusedKernelParams<Kernel, OutputStream>& params,
106 OutType* result);
107};
108
// Parameters for a one-dimensional transform. Holds the input and output
// buffers, scratch memory and the transform-specific `kernel` value. The
// member aliases re-export the template arguments.
template <typename InType_, typename OutType_, typename Kernel_>
struct Transform1DParams {
  using InType = InType_;
  using OutType = OutType_;
  using Kernel = Kernel_;

  const InType* input;    // source data (read-only)
  OutType* output;        // destination buffer
  std::uint8_t* scratch;  // scratch memory; size requirements not stated here

  Kernel kernel;          // transform-specific parameters
};
121
// Generic declaration of a one-dimensional transform kernel. It reads from
// `input` (const) and writes to `output`, configured by `params`.
// Declaration only -- no definition is visible in this header.
// NOTE(review): `kernel_size` / `leftovers` presumably give elements per step
// and tail length -- confirm.
template <typename InType, typename OutType, typename Kernel,
          int kernel_size, int leftovers>
class Transform1DKernel {
 public:
  static void Transform(const InType* input, const Kernel& params,
                        OutType* output);
};
129
// Helpers tied to a particular transform parameter type `Transform`.
// Declaration only -- no definitions are visible in this header.
template <typename InType, typename OutType, typename Transform>
class Transform1DUtil {
 public:
  // Return `input` / `output` moved by `offset`.
  // NOTE(review): the unit of `offset` (elements vs. bytes) is not
  // established here -- confirm against the definitions.
  static const InType* OffsetInput(const Transform& params, const InType* input,
                                   int offset);
  static OutType* OffsetOutput(const Transform& params, OutType* output,
                               int offset);

  // Estimated cost of running the transform described by `params`.
  // NOTE(review): presumably used to decide how to split work -- confirm.
  static int EstimateComputeCost(const Transform& params);
};
141
142} // namespace meta
143} // namespace gemmlowp
144
145#endif // GEMMLOWP_META_BASE_H_
146