/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/meta_support.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

#if defined(GEMMLOWP_NEON_32) && !defined(TENSORFLOW_DISABLE_META) && \
    !defined(__APPLE__)
#define TENSORFLOW_USE_META (1)
#endif

namespace tensorflow {
namespace meta {

namespace {

int g_num_threads = 0;
bool g_enabled = true;
bool g_use_local_context = false;

#ifdef TENSORFLOW_USE_META

const int kAlignment = 32;
const int kScratchSize = 2048 * 1024 + kAlignment;

class Scratch : public ResourceBase {
 public:
  Scratch() : scratch_(new uint8_t[kScratchSize]) {
    // Make sure the usable scratch pointer is aligned to kAlignment (32)
    // bytes. The Scratch object owns the underlying buffer.
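    // For example, if scratch_.get() is 0x1005, then 0x1005 % 32 == 5 and the
    // aligned pointer is 0x1005 + 32 - 5 == 0x1020; the extra kAlignment
    // bytes in kScratchSize keep this within the allocation.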
    scratch_32_aligned_ =
        scratch_.get() + kAlignment -
        (reinterpret_cast<uintptr_t>(scratch_.get()) % kAlignment);
  }

  uint8_t* buffer() { return scratch_32_aligned_; }

  string DebugString() const override { return "MetaGemmScratchResource"; }

 private:
  std::unique_ptr<uint8_t[]> scratch_;
  uint8_t* scratch_32_aligned_;
};

uint8_t* GetScratch(OpKernelContext* context) {
  Scratch* scratch = nullptr;
  std::function<Status(Scratch**)> creator = [](Scratch** resource) {
    *resource = new Scratch();
    return OkStatus();
  };
  Status s = context->resource_manager()->LookupOrCreate(
      "MetaGemm", "ScratchBuffer", &scratch, creator);
  if (!s.ok()) {
    context->CtxFailureWithWarning(s);
    return nullptr;
  }
  return scratch->buffer();
}

gemmlowp::WorkersPool* GetWorkersPool() {
  static gemmlowp::WorkersPool* pool = new gemmlowp::WorkersPool();
  return pool;
}

mutex& GetMutex() {
  static mutex mu(LINKER_INITIALIZED);
  return mu;
}

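// Returns the thread count set via SetNumThreads(), or the device's intra-op
// worker thread count when no override (0) is configured.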
int GetWorkersCount(OpKernelContext* tf_context) {
  if (g_num_threads == 0) {
    return tf_context->device()->tensorflow_cpu_worker_threads()->num_threads;
  }
  return g_num_threads;
}

typedef gemmlowp::meta::SimpleContext<gemmlowp::WorkersPool> LocalContext;

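// Chooses a gemmlowp meta GEMM executor based on the output shape: outputs
// with at most 4 rows (m <= 4) use LHS-cache-friendly packing with kernel
// parameters (1, 8, 8); otherwise packing follows the larger output dimension
// (RHS-cache-friendly when m >= n, LHS-cache-friendly otherwise) with kernel
// parameters (2, 4, 8).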
template <typename Context, typename Params>
void MultiThreadGemm(Context* context, const Params& params) {
  if (params.m <= 4) {
    gemmlowp::meta::MultiThreadGemm<
        Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, 1,
        8, 8>(context, params);
  } else {
    if (params.m >= params.n) {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackRHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    } else {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    }
  }
}

template <typename LeftStream, typename RightStream>
void QuantizedGemmImpl(OpKernelContext* tf_context, const quint8* a_data,
                       const quint8* b_data, qint32* c_data, int m, int n,
                       int k, int offset_a, int offset_b, int lda, int ldb,
                       int ldc) {
  typedef gemmlowp::meta::GemmParams<
      uint8_t, int32_t, LeftStream, RightStream,
      gemmlowp::meta::QuantizedStaticPreprocessedAsInt32,
      gemmlowp::meta::RowMajor>
      Params;
  Params params;

  params.m = m;
  params.n = n;
  params.k = k;

  params.lhs = reinterpret_cast<const uint8_t*>(&(a_data->value));
  params.rhs = reinterpret_cast<const uint8_t*>(&(b_data->value));
  params.result = reinterpret_cast<int32_t*>(&(c_data->value));
  params.scratch = CHECK_NOTNULL(GetScratch(tf_context));

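  // The stream sum offsets below match the algebraic expansion
  //   sum_k (a + offset_a) * (b + offset_b) =
  //       sum_k a*b + offset_b * sum_k a + offset_a * sum_k b
  //       + k * offset_a * offset_b,
  // with the constant k * offset_a * offset_b term folded into the left
  // stream's additive offset.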
  params.left_stream.count = k;
  params.left_stream.stride = lda;
  params.left_stream.multiplicative_sum_offset = offset_b;
  params.left_stream.additive_sum_offset = k * offset_a * offset_b;

  params.right_stream.count = k;
  params.right_stream.stride = ldb;
  params.right_stream.multiplicative_sum_offset = offset_a;
  params.right_stream.additive_sum_offset = 0;

  params.fused_kernel.kernel.count = k;
  params.fused_kernel.output_stream.stride = ldc * sizeof(int32_t);

  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    MultiThreadGemm<LocalContext, Params>(&local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    MultiThreadGemm<TensorflowGemmContext, Params>(&context, params);
  }
}

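// Runs a gemmlowp meta 1-D transform on either the shared local worker pool
// or the TensorFlow intra-op thread pool, mirroring the context selection in
// QuantizedGemmImpl above.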
template <typename Params, int kernel_size>
void MultiThreadTransform1D(OpKernelContext* tf_context, const Params& params) {
  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    gemmlowp::meta::MultiThreadTransform1D<LocalContext, Params, kernel_size>(
        &local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    gemmlowp::meta::MultiThreadTransform1D<TensorflowGemmContext, Params,
                                           kernel_size>(&context, params);
  }
}

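// Maps the float range [min, max] onto the 2^bits levels of QuantizedType.
// For example, for uint8_t and the range [-1.0f, 1.0f] the scale is 2.0 / 255.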
template <typename QuantizedType>
double CalculateRangeScale(float min, float max) {
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>(max - min) /
         ((static_cast<int64_t>(1) << bits) - 1);
}

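// Reciprocal of CalculateRangeScale, returning 0.0 for a degenerate
// (min == max) range instead of dividing by zero.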
template <typename QuantizedType>
double CalculateOneOverRangeScale(float min, float max) {
  if (min == max) {
    return 0.0;
  }
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>((static_cast<int64_t>(1) << bits) - 1) /
         (max - min);
}

#endif  // TENSORFLOW_USE_META

}  // namespace

void SetNumThreads(int num_threads) { g_num_threads = num_threads; }

int GetNumThreads() { return g_num_threads; }

void SetUseLocalContext(bool use_local_context) {
  g_use_local_context = use_local_context;
}

bool GetUseLocalContext() { return g_use_local_context; }

bool IsSupported() {
#if defined(TENSORFLOW_USE_META)
  return true;
#else
  return false;
#endif
}

bool IsEnabled() { return g_enabled; }

void SetEnabled(bool enabled) { g_enabled = enabled; }

bool IsSupportedAndEnabled() { return IsSupported() && IsEnabled(); }

void QuantizedGemm(OpKernelContext* tf_context, bool transpose_a,
                   bool transpose_b, const quint8* a_data, const quint8* b_data,
                   qint32* c_data, int m, int n, int k, int offset_a,
                   int offset_b, int lda, int ldb, int ldc) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
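  // Select the gemmlowp stream type matching each operand's layout: A uses a
  // row-major stream unless transposed, B uses a column-major stream unless
  // transposed.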
  if (transpose_a) {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  } else {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  }
#else
  LOG(FATAL) << "QuantizedGemm: Meta fastpath not supported.";
#endif
}

void Requantize(OpKernelContext* tf_context, const qint32* input, int count,
                float input_min, float input_max, float output_min,
                float output_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<int32_t, uint8_t,
                                            gemmlowp::meta::Requantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const int32_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.input_range_min = input_min;
  params.kernel.output_range_min = output_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<int32_t>(input_min, input_max);
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<uint8_t>(output_min, output_max);
  params.kernel.input_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

#if defined(GEMMLOWP_NEON_32)
  // After adding output_range_offset, the value is cast from float to uint.
  // The float-to-int/uint cast on 32-bit ARM rounds toward zero, whereas
  // Eigen rounds to nearest. Adding 0.5f before the truncating cast restores
  // round-to-nearest behavior, since all values here are non-negative.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta
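  // For example, 3.7f would truncate to 3, while 3.7f + 0.5f == 4.2f
  // truncates to 4, matching round-to-nearest.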
  params.kernel.output_range_offset += 0.5f;
#endif

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Requantize: Meta fastpath not supported.";
#endif
}

void Dequantize(OpKernelContext* tf_context, const quint8* input, int count,
                float range_min, float range_max, float* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, float,
                                            gemmlowp::meta::Dequantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<float*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Dequantize: Meta fastpath not supported.";
#endif
}

void Quantize(OpKernelContext* tf_context, const float* input, int count,
              float range_min, float range_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<float, uint8_t,
                                            gemmlowp::meta::Quantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const float*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateOneOverRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

#if defined(GEMMLOWP_NEON_32)
  // The float-to-int/uint cast on 32-bit ARM rounds toward zero, whereas
  // Eigen rounds to nearest. Adding 0.5f before the truncating cast restores
  // round-to-nearest behavior, since all values here are non-negative.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta
  params.kernel.range_offset += 0.5f;
#endif

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Quantize: Meta fastpath not supported.";
#endif
}

void QuantizedBiasAdd(OpKernelContext* tf_context, const quint8* input,
                      int input_count, const quint8* bias, int bias_count,
                      float input_min, float input_max, float bias_min,
                      float bias_max, float output_min, float output_max,
                      qint32* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, int32_t,
                                            gemmlowp::meta::BiasAdd<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<int32_t*>(output);
  params.kernel.bias = reinterpret_cast<const uint8_t*>(bias);
  params.kernel.count = bias_count;
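  // The bias vector of length bias_count is broadcast over
  // input_count / bias_count rows of the input.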
  params.kernel.rows = input_count / bias_count;
  params.kernel.input_range_min = input_min;
  params.kernel.bias_range_min = bias_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<uint8_t>(input_min, input_max);
  params.kernel.bias_range_scale =
      CalculateRangeScale<uint8_t>(bias_min, bias_max);
  params.kernel.input_range_offset = 0;
  params.kernel.bias_range_offset = 0;
  params.kernel.output_range_min = output_min;
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<int32_t>(output_min, output_max);
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // TODO(maciekc): add multithreading to bias add.
  // Right now this kernel does not support multithreaded execution.
  gemmlowp::meta::Transform1D<Params, 16>(params);
#else
  LOG(FATAL) << "QuantizedBiasAdd: Meta fastpath not supported.";
#endif
}

void Clamp(OpKernelContext* tf_context, const quint8* input, int count,
           quint8 clamp_min, quint8 clamp_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, uint8_t,
                                            gemmlowp::meta::MinMax<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.min = clamp_min;
  params.kernel.max = clamp_max;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Clamp: Meta fastpath not supported.";
#endif
}

}  // namespace meta
}  // namespace tensorflow