1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/eigen_support.h"
16
17#include <functional>
18#include <memory>
19#include <utility>
20
21#include "tensorflow/lite/arena_planner.h"
22#include "tensorflow/lite/c/common.h"
23#include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
24#include "tensorflow/lite/kernels/op_macros.h"
25
26namespace tflite {
27namespace eigen_support {
28namespace {
29
// Unless the context explicitly specifies a thread count, we fall back to 4
// threads; this default exists for legacy-compatibility reasons.
const int kDefaultNumThreadpoolThreads = 4;

// A requested thread count is acceptable if it is -1 ("unspecified") or any
// non-negative value.
bool IsValidNumThreads(int num_threads) { return num_threads >= -1; }

// Maps an unspecified (-1) thread count to the legacy default; all other
// values pass through unchanged.
int GetNumThreads(int num_threads) {
  return num_threads >= 0 ? num_threads : kDefaultNumThreadpoolThreads;
}
38
39#ifndef EIGEN_DONT_ALIGN
40// Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on
41// hardware architecture and build configurations.
42// If the static assertion fails, try to increase `kDefaultTensorAlignment` to
43// in `arena_planner.h` to 32 or 64.
44static_assert(
45 kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0,
46 "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement.");
47#endif // EIGEN_DONT_ALIGN
48
// Helper that propagates a thread count to Eigen's process-wide setting,
// which is consulted by Eigen's OpenMP backend.
void SetEigenNbThreads(int num_threads) {
#if defined(EIGEN_HAS_OPENMP)
  // Eigen only reads this global when OpenMP is enabled, and the call is
  // known to upset tsan, so restrict it to OpenMP builds.
  Eigen::setNbThreads(num_threads);
#endif  // defined(EIGEN_HAS_OPENMP)
}
57
58// We have a single global threadpool for all convolution operations. This means
59// that inferences started from different threads may block each other, but
60// since the underlying resource of CPU cores should be consumed by the
61// operations anyway, it shouldn't affect overall performance. Note that we
62// also avoid ThreadPool creation if the target thread count is 1, avoiding
63// unnecessary overhead, and more closely mimicking Gemmlowp threadpool
64// behavior.
65class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
66 public:
67 // Takes ownership of 'pool'
68 explicit EigenThreadPoolWrapper(int num_threads) {
69 // Avoid creating any threads for the single-threaded case.
70 if (num_threads > 1) {
71 pool_ = std::make_unique<Eigen::ThreadPool>(num_threads);
72 }
73 }
74 ~EigenThreadPoolWrapper() override {}
75
76 void Schedule(std::function<void()> fn) override {
77 if (pool_) {
78 pool_->Schedule(std::move(fn));
79 } else {
80 fn();
81 }
82 }
83 int NumThreads() const override { return pool_ ? pool_->NumThreads() : 1; }
84 int CurrentThreadId() const override {
85 return pool_ ? pool_->CurrentThreadId() : 0;
86 }
87
88 private:
89 // May be null if num_threads <= 1.
90 std::unique_ptr<Eigen::ThreadPool> pool_;
91};
92
93// Utility class for lazily creating an Eigen thread pool/device only when used.
94class LazyEigenThreadPoolHolder {
95 public:
96 explicit LazyEigenThreadPoolHolder(int num_threads) {
97 SetNumThreads(num_threads);
98 }
99
100 // Gets the ThreadPoolDevice, creating if necessary.
101 const Eigen::ThreadPoolDevice* GetThreadPoolDevice() {
102 if (!device_) {
103 thread_pool_wrapper_ =
104 std::make_unique<EigenThreadPoolWrapper>(target_num_threads_);
105 device_ = std::make_unique<Eigen::ThreadPoolDevice>(
106 thread_pool_wrapper_.get(), target_num_threads_);
107 }
108 return device_.get();
109 }
110
111 // Updates the thread count, invalidating the ThreadPoolDevice if necessary.
112 void SetNumThreads(int num_threads) {
113 const int target_num_threads = GetNumThreads(num_threads);
114 if (target_num_threads_ != target_num_threads) {
115 target_num_threads_ = target_num_threads;
116 // As the device references the thread pool wrapper, destroy it first.
117 device_.reset();
118 thread_pool_wrapper_.reset();
119 }
120 }
121
122 private:
123 int target_num_threads_ = kDefaultNumThreadpoolThreads;
124 // Both device_ and thread_pool_wrapper_ are lazily created.
125 std::unique_ptr<Eigen::ThreadPoolDevice> device_;
126 std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper_;
127};
128
// External context attached to a TfLiteContext; reference-counted so multiple
// kernels can share one Eigen thread pool holder per interpreter.
struct RefCountedEigenContext : public TfLiteExternalContext {
  // Lazily materializes the Eigen device/pool shared by Eigen-backed kernels.
  std::unique_ptr<LazyEigenThreadPoolHolder> thread_pool_holder;
  // IncrementUsageCounter() calls not yet matched by DecrementUsageCounter();
  // the context is destroyed when this drops to zero.
  int num_references = 0;
};
133
134RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
135 return reinterpret_cast<RefCountedEigenContext*>(
136 context->GetExternalContext(context, kTfLiteEigenContext));
137}
138
139TfLiteStatus Refresh(TfLiteContext* context) {
140 if (IsValidNumThreads(context->recommended_num_threads)) {
141 SetEigenNbThreads(GetNumThreads(context->recommended_num_threads));
142 }
143
144 auto* ptr = GetEigenContext(context);
145 if (ptr != nullptr) {
146 ptr->thread_pool_holder->SetNumThreads(context->recommended_num_threads);
147 }
148
149 return kTfLiteOk;
150}
151
152} // namespace
153
154void IncrementUsageCounter(TfLiteContext* context) {
155 auto* ptr = GetEigenContext(context);
156 if (ptr == nullptr) {
157 if (IsValidNumThreads(context->recommended_num_threads)) {
158 SetEigenNbThreads(context->recommended_num_threads);
159 }
160 ptr = new RefCountedEigenContext;
161 ptr->type = kTfLiteEigenContext;
162 ptr->Refresh = Refresh;
163 ptr->thread_pool_holder = std::make_unique<LazyEigenThreadPoolHolder>(
164 context->recommended_num_threads);
165 ptr->num_references = 0;
166 context->SetExternalContext(context, kTfLiteEigenContext, ptr);
167 }
168 ptr->num_references++;
169}
170
171void DecrementUsageCounter(TfLiteContext* context) {
172 auto* ptr = GetEigenContext(context);
173 if (ptr == nullptr) {
174 TF_LITE_FATAL(
175 "Call to DecrementUsageCounter() not preceded by "
176 "IncrementUsageCounter()");
177 }
178 if (--ptr->num_references == 0) {
179 delete ptr;
180 context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
181 }
182}
183
184const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
185 auto* ptr = GetEigenContext(context);
186 if (ptr == nullptr) {
187 TF_LITE_FATAL(
188 "Call to GetFromContext() not preceded by IncrementUsageCounter()");
189 }
190 return ptr->thread_pool_holder->GetThreadPoolDevice();
191}
192
193} // namespace eigen_support
194} // namespace tflite
195