1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Implements a quantized eight-bit version of the matmul operation. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
#include <limits>

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
31 | |
32 | namespace tensorflow { |
33 | |
34 | // We have to break this out as a separate function because there are multiple |
35 | // combinations of transpose attributes we need to support, and they have to be |
36 | // compile-time constants to work with the templates used internally. |
37 | template <bool TransposeA, bool TransposeB, bool TransposeC> |
38 | void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data, |
39 | const quint8* b_data, qint32* c_data, int m, int n, int k, |
40 | int offset_a, int offset_b, int lda, int ldb, int ldc) { |
41 | const uint8* a_data_as_uint8 = &(a_data->value); |
42 | const uint8* b_data_as_uint8 = &(b_data->value); |
43 | int32* c_data_as_int32 = &(c_data->value); |
44 | static const gemmlowp::MapOrder ResultOrder = |
45 | !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; |
46 | static const gemmlowp::MapOrder LhsOrder = |
47 | !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; |
48 | static const gemmlowp::MapOrder RhsOrder = |
49 | !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor; |
50 | gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k, |
51 | lda); |
52 | gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n, |
53 | ldb); |
54 | gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n, |
55 | ldc); |
56 | const std::tuple<> empty_pipeline = {}; |
57 | auto& worker_threads = |
58 | *(op_context->device()->tensorflow_cpu_worker_threads()); |
59 | TensorflowGemmContext context(worker_threads.num_threads, |
60 | worker_threads.workers); |
61 | gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t, |
62 | gemmlowp::DefaultL8R8BitDepthParams>( |
63 | &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline); |
64 | // Since gemmlowp uses assembly to write to the output, msan won't detect |
65 | // the output buffer as written to, so we mark it manually. |
66 | TF_ANNOTATE_MEMORY_IS_INITIALIZED(c_data_as_int32, m * n * sizeof(int32)); |
67 | } |
68 | |
69 | template <class T1, class T2, class Toutput> |
70 | class QuantizedMatMulOp : public OpKernel { |
71 | public: |
72 | explicit QuantizedMatMulOp(OpKernelConstruction* context) |
73 | : OpKernel(context) { |
74 | OP_REQUIRES_OK(context, context->GetAttr("transpose_a" , &transpose_a_)); |
75 | OP_REQUIRES_OK(context, context->GetAttr("transpose_b" , &transpose_b_)); |
76 | } |
77 | |
78 | void Compute(OpKernelContext* context) override { |
79 | const Tensor& a = context->input(0); |
80 | const Tensor& b = context->input(1); |
81 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(2).shape()), |
82 | errors::InvalidArgument("min_a must be a scalar, but got shape" , |
83 | context->input(2).shape())); |
84 | const float min_a = context->input(2).flat<float>()(0); |
85 | OP_REQUIRES(context, context->input(3).NumElements() == 1, |
86 | errors::InvalidArgument("max_a must be a scalar, but got shape" , |
87 | context->input(3).shape())); |
88 | const float max_a = context->input(3).flat<float>()(0); |
89 | OP_REQUIRES(context, context->input(4).NumElements() == 1, |
90 | errors::InvalidArgument("min_b must be a scalar, but got shape" , |
91 | context->input(4).shape())); |
92 | const float min_b = context->input(4).flat<float>()(0); |
93 | OP_REQUIRES(context, context->input(5).NumElements() == 1, |
94 | errors::InvalidArgument("max_b must be a scalar, but got shape" , |
95 | context->input(5).shape())); |
96 | const float max_b = context->input(5).flat<float>()(0); |
97 | |
98 | // Make sure that we have valid quantization ranges for the input buffers. |
99 | // If the difference between the min and max is negative or zero, it makes |
100 | // it hard to do meaningful intermediate operations on the values. |
101 | OP_REQUIRES(context, (max_a > min_a), |
102 | errors::InvalidArgument("max_a must be larger than min_a." )); |
103 | OP_REQUIRES(context, (max_b > min_b), |
104 | errors::InvalidArgument("max_b must be larger than min_b." )); |
105 | const int32_t offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a); |
106 | const int32_t offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b); |
107 | const int32_t offset_c = 0; |
108 | const int32_t mult_c = 1; |
109 | const int32_t shift_c = 0; |
110 | |
111 | // Check that the dimensions of the two matrices are valid. |
112 | OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()), |
113 | errors::InvalidArgument("In[0] is not a matrix" )); |
114 | OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()), |
115 | errors::InvalidArgument("In[1] is not a matrix" )); |
116 | Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; |
117 | dim_pair[0].first = transpose_a_ ? 0 : 1; |
118 | dim_pair[0].second = transpose_b_ ? 1 : 0; |
119 | |
120 | OP_REQUIRES(context, |
121 | a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), |
122 | errors::InvalidArgument("Matrix size-incompatible: In[0]: " , |
123 | a.shape().DebugString(), |
124 | ", In[1]: " , b.shape().DebugString())); |
125 | |
126 | OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)), |
127 | errors::InvalidArgument("shift_c must be between 0 and 31, " |
128 | "inclusive." )); |
129 | |
130 | int a_dim_remaining = 1 - dim_pair[0].first; |
131 | int b_dim_remaining = 1 - dim_pair[0].second; |
132 | TensorShape out_shape( |
133 | {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)}); |
134 | Tensor* c = nullptr; |
135 | OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c)); |
136 | CHECK(c); |
137 | |
138 | const T1* a_data = a.flat<T1>().data(); |
139 | const T2* b_data = b.flat<T2>().data(); |
140 | Toutput* c_data = c->flat<Toutput>().data(); |
141 | |
142 | const bool transpose_c = false; |
143 | const size_t m = a.dim_size(a_dim_remaining); |
144 | const size_t n = b.dim_size(b_dim_remaining); |
145 | const size_t k = a.dim_size(dim_pair[0].first); |
146 | const size_t lda = a.dim_size(1); |
147 | const size_t ldb = b.dim_size(1); |
148 | const size_t ldc = n; |
149 | |
150 | if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() && |
151 | std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() && |
152 | (offset_c == 0) && (mult_c == 1) && (shift_c == 0) && |
153 | (transpose_c == false) && (k <= 2048)) { |
154 | // Gemmlowp/meta code path works on 32 & 64 bit Arm with NEON Simd and |
155 | // allows optimized quantized 8bit to 32bit gemm. |
156 | meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data, |
157 | c_data, m, n, k, -offset_a, -offset_b, lda, ldb, ldc); |
158 | } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() && |
159 | std::is_same<Toutput, qint32>() && (offset_c == 0) && |
160 | (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) { |
161 | // The gemmlowp optimized library only works for a particular set of data |
162 | // types, so check if we meet those requirements and fall back to a slower |
163 | // reference implementation if not. |
164 | if (transpose_a_) { |
165 | if (transpose_b_) { |
166 | GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data, |
167 | m, n, k, offset_a, offset_b, lda, |
168 | ldb, ldc); |
169 | } else { |
170 | GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data, |
171 | m, n, k, offset_a, offset_b, lda, |
172 | ldb, ldc); |
173 | } |
174 | } else { |
175 | if (transpose_b_) { |
176 | GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data, |
177 | m, n, k, offset_a, offset_b, lda, |
178 | ldb, ldc); |
179 | } else { |
180 | GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data, |
181 | m, n, k, offset_a, offset_b, |
182 | lda, ldb, ldc); |
183 | } |
184 | } |
185 | } else { |
186 | ReferenceGemm<T1, T2, Toutput>( |
187 | transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a, |
188 | lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc); |
189 | } |
190 | |
191 | float min_c_value; |
192 | float max_c_value; |
193 | QuantizationRangeForMultiplication<T1, T2, Toutput>( |
194 | min_a, max_a, min_b, max_b, &min_c_value, &max_c_value); |
195 | Tensor* c_min = nullptr; |
196 | OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min)); |
197 | c_min->flat<float>()(0) = min_c_value; |
198 | |
199 | Tensor* c_max = nullptr; |
200 | OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max)); |
201 | c_max->flat<float>()(0) = max_c_value; |
202 | } |
203 | |
204 | private: |
205 | bool transpose_a_; |
206 | bool transpose_b_; |
207 | }; |
208 | |
// Registers the CPU kernel for the "QuantizedMatMul" op, instantiated for
// quint8 inputs (T1, T2) and a qint32 output (Toutput).
REGISTER_KERNEL_BUILDER(Name("QuantizedMatMul")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T1")
                            .TypeConstraint<quint8>("T2")
                            .TypeConstraint<qint32>("Toutput"),
                        QuantizedMatMulOp<quint8, quint8, qint32>);
215 | |
216 | } // namespace tensorflow |
217 | |