/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the matmul operation.

#define EIGEN_USE_THREADS

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// We have to break this out as a separate function because there are multiple
// combinations of transpose attributes we need to support, and they have to be
// compile-time constants to work with the templates used internally.
template <bool TransposeA, bool TransposeB, bool TransposeC>
void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
                      const quint8* b_data, qint32* c_data, int m, int n, int k,
                      int offset_a, int offset_b, int lda, int ldb, int ldc) {
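  // quint8 and qint32 are thin wrapper structs around a single integer value,
  // so taking the address of the first element's .value member reinterprets
  // the tensor buffers as the raw uint8/int32 arrays that gemmlowp expects.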
  const uint8* a_data_as_uint8 = &(a_data->value);
  const uint8* b_data_as_uint8 = &(b_data->value);
  int32* c_data_as_int32 = &(c_data->value);
  static const gemmlowp::MapOrder ResultOrder =
      !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
  static const gemmlowp::MapOrder LhsOrder =
      !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
  static const gemmlowp::MapOrder RhsOrder =
      !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
  gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k,
                                                        lda);
  gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n,
                                                        ldb);
  gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
                                                        ldc);
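  // An empty output pipeline leaves the result as raw 32-bit accumulators;
  // no requantization, bias, or activation stage is applied here.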
  const std::tuple<> empty_pipeline = {};
  auto& worker_threads =
      *(op_context->device()->tensorflow_cpu_worker_threads());
  TensorflowGemmContext context(worker_threads.num_threads,
                                worker_threads.workers);
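  // gemmlowp adds the given offsets to every lhs/rhs element before
  // multiplying, so passing -offset_a and -offset_b subtracts the zero points
  // and the accumulators hold sum_k (a - offset_a) * (b - offset_b).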
  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                   gemmlowp::DefaultL8R8BitDepthParams>(
      &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
  // Since gemmlowp uses assembly to write to the output, msan won't detect
  // the output buffer as written to, so we mark it manually.
  TF_ANNOTATE_MEMORY_IS_INITIALIZED(c_data_as_int32, m * n * sizeof(int32));
}

template <class T1, class T2, class Toutput>
class QuantizedMatMulOp : public OpKernel {
 public:
  explicit QuantizedMatMulOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(2).shape()),
                errors::InvalidArgument(
                    "min_a must be a scalar, but got shape ",
                    context->input(2).shape()));
    const float min_a = context->input(2).flat<float>()(0);
    OP_REQUIRES(context, context->input(3).NumElements() == 1,
                errors::InvalidArgument(
                    "max_a must be a scalar, but got shape ",
                    context->input(3).shape()));
    const float max_a = context->input(3).flat<float>()(0);
    OP_REQUIRES(context, context->input(4).NumElements() == 1,
                errors::InvalidArgument(
                    "min_b must be a scalar, but got shape ",
                    context->input(4).shape()));
    const float min_b = context->input(4).flat<float>()(0);
    OP_REQUIRES(context, context->input(5).NumElements() == 1,
                errors::InvalidArgument(
                    "max_b must be a scalar, but got shape ",
                    context->input(5).shape()));
    const float max_b = context->input(5).flat<float>()(0);

    // Make sure that we have valid quantization ranges for the input buffers.
    // If max is not strictly greater than min, the quantized values can't be
    // interpreted meaningfully in the intermediate operations.
    OP_REQUIRES(context, (max_a > min_a),
                errors::InvalidArgument("max_a must be larger than min_a."));
    OP_REQUIRES(context, (max_b > min_b),
                errors::InvalidArgument("max_b must be larger than min_b."));
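    // offset_a and offset_b are the quantized representations of 0.0f in each
    // input's range (their zero points). The output is kept as raw 32-bit
    // accumulators, so no output offset, multiplier, or shift is applied.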
    const int32_t offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a);
    const int32_t offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b);
    const int32_t offset_c = 0;
    const int32_t mult_c = 1;
    const int32_t shift_c = 0;

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
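    // dim_pair[0] identifies the contracted (inner) dimension of each input,
    // which depends on the transpose attributes.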
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(context,
                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
                                        a.shape().DebugString(),
                                        ", In[1]: ", b.shape().DebugString()));

    OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)),
                errors::InvalidArgument("shift_c must be between 0 and 31, "
                                        "inclusive."));

    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* c = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
    CHECK(c);

    const T1* a_data = a.flat<T1>().data();
    const T2* b_data = b.flat<T2>().data();
    Toutput* c_data = c->flat<Toutput>().data();

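    // m and n are the output dimensions and k is the contracted dimension;
    // lda, ldb, and ldc are the leading dimensions (row strides) of the
    // row-major input and output buffers as stored in memory.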
    const bool transpose_c = false;
    const size_t m = a.dim_size(a_dim_remaining);
    const size_t n = b.dim_size(b_dim_remaining);
    const size_t k = a.dim_size(dim_pair[0].first);
    const size_t lda = a.dim_size(1);
    const size_t ldb = b.dim_size(1);
    const size_t ldc = n;

    if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
        std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() &&
        (offset_c == 0) && (mult_c == 1) && (shift_c == 0) &&
        (transpose_c == false) && (k <= 2048)) {
      // The gemmlowp/meta code path works on 32- and 64-bit ARM with NEON
      // SIMD and provides an optimized quantized 8-bit to 32-bit gemm.
      meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data,
                          c_data, m, n, k, -offset_a, -offset_b, lda, ldb, ldc);
    } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
               std::is_same<Toutput, qint32>() && (offset_c == 0) &&
               (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) {
      // The optimized gemmlowp library only works for this particular
      // combination of data types and parameters, so check whether we meet
      // those requirements and fall back to the slower reference
      // implementation below if not.
      if (transpose_a_) {
        if (transpose_b_) {
          GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
                                              m, n, k, offset_a, offset_b, lda,
                                              ldb, ldc);
        } else {
          GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
                                               m, n, k, offset_a, offset_b, lda,
                                               ldb, ldc);
        }
      } else {
        if (transpose_b_) {
          GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
                                               m, n, k, offset_a, offset_b, lda,
                                               ldb, ldc);
        } else {
          GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data,
                                                m, n, k, offset_a, offset_b,
                                                lda, ldb, ldc);
        }
      }
    } else {
      ReferenceGemm<T1, T2, Toutput>(
          transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a,
          lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc);
    }

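    // The int32 accumulators implicitly cover the product of the two input
    // ranges; compute the float min/max that those values map to so that
    // downstream ops can interpret the quantized output.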
    float min_c_value;
    float max_c_value;
    QuantizationRangeForMultiplication<T1, T2, Toutput>(
        min_a, max_a, min_b, max_b, &min_c_value, &max_c_value);
    Tensor* c_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min));
    c_min->flat<float>()(0) = min_c_value;

    Tensor* c_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max));
    c_max->flat<float>()(0) = max_c_value;
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
};

REGISTER_KERNEL_BUILDER(Name("QuantizedMatMul")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T1")
                            .TypeConstraint<quint8>("T2")
                            .TypeConstraint<qint32>("Toutput"),
                        QuantizedMatMulOp<quint8, quint8, qint32>);

}  // namespace tensorflow