1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/eigen_support.h" |
16 | |
17 | #include <functional> |
18 | #include <memory> |
19 | #include <utility> |
20 | |
21 | #include "tensorflow/lite/arena_planner.h" |
22 | #include "tensorflow/lite/c/common.h" |
23 | #include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" |
24 | #include "tensorflow/lite/kernels/op_macros.h" |
25 | |
26 | namespace tflite { |
27 | namespace eigen_support { |
28 | namespace { |
29 | |
// For legacy reasons, we use 4 threads by default unless the thread count is
// explicitly specified by the context.
const int kDefaultNumThreadpoolThreads = 4;

// A requested thread count is acceptable when it is non-negative, or -1 to
// ask for the default.
bool IsValidNumThreads(int num_threads) { return num_threads >= -1; }

// Maps the caller-provided thread count to the effective one: non-negative
// values are used as-is, anything else selects the legacy default of 4.
int GetNumThreads(int num_threads) {
  if (num_threads > -1) {
    return num_threads;
  }
  return kDefaultNumThreadpoolThreads;
}
38 | |
39 | #ifndef EIGEN_DONT_ALIGN |
40 | // Eigen may require buffers to be aligned to 16, 32 or 64 bytes depending on |
41 | // hardware architecture and build configurations. |
42 | // If the static assertion fails, try to increase `kDefaultTensorAlignment` to |
43 | // in `arena_planner.h` to 32 or 64. |
44 | static_assert( |
45 | kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0, |
46 | "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement." ); |
47 | #endif // EIGEN_DONT_ALIGN |
48 | |
// Helper routine for updating the global Eigen thread count used for OpenMP.
void SetEigenNbThreads(int threads) {
#if defined(EIGEN_HAS_OPENMP)
  // The global Eigen thread count only matters when OpenMP is enabled, and
  // the call trips up tsan otherwise, so guard it behind the OpenMP check.
  Eigen::setNbThreads(threads);
#else
  (void)threads;  // Intentionally unused without OpenMP.
#endif  // defined(EIGEN_HAS_OPENMP)
}
57 | |
58 | // We have a single global threadpool for all convolution operations. This means |
59 | // that inferences started from different threads may block each other, but |
60 | // since the underlying resource of CPU cores should be consumed by the |
61 | // operations anyway, it shouldn't affect overall performance. Note that we |
62 | // also avoid ThreadPool creation if the target thread count is 1, avoiding |
63 | // unnecessary overhead, and more closely mimicking Gemmlowp threadpool |
64 | // behavior. |
65 | class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { |
66 | public: |
67 | // Takes ownership of 'pool' |
68 | explicit EigenThreadPoolWrapper(int num_threads) { |
69 | // Avoid creating any threads for the single-threaded case. |
70 | if (num_threads > 1) { |
71 | pool_ = std::make_unique<Eigen::ThreadPool>(num_threads); |
72 | } |
73 | } |
74 | ~EigenThreadPoolWrapper() override {} |
75 | |
76 | void Schedule(std::function<void()> fn) override { |
77 | if (pool_) { |
78 | pool_->Schedule(std::move(fn)); |
79 | } else { |
80 | fn(); |
81 | } |
82 | } |
83 | int NumThreads() const override { return pool_ ? pool_->NumThreads() : 1; } |
84 | int CurrentThreadId() const override { |
85 | return pool_ ? pool_->CurrentThreadId() : 0; |
86 | } |
87 | |
88 | private: |
89 | // May be null if num_threads <= 1. |
90 | std::unique_ptr<Eigen::ThreadPool> pool_; |
91 | }; |
92 | |
93 | // Utility class for lazily creating an Eigen thread pool/device only when used. |
94 | class LazyEigenThreadPoolHolder { |
95 | public: |
96 | explicit LazyEigenThreadPoolHolder(int num_threads) { |
97 | SetNumThreads(num_threads); |
98 | } |
99 | |
100 | // Gets the ThreadPoolDevice, creating if necessary. |
101 | const Eigen::ThreadPoolDevice* GetThreadPoolDevice() { |
102 | if (!device_) { |
103 | thread_pool_wrapper_ = |
104 | std::make_unique<EigenThreadPoolWrapper>(target_num_threads_); |
105 | device_ = std::make_unique<Eigen::ThreadPoolDevice>( |
106 | thread_pool_wrapper_.get(), target_num_threads_); |
107 | } |
108 | return device_.get(); |
109 | } |
110 | |
111 | // Updates the thread count, invalidating the ThreadPoolDevice if necessary. |
112 | void SetNumThreads(int num_threads) { |
113 | const int target_num_threads = GetNumThreads(num_threads); |
114 | if (target_num_threads_ != target_num_threads) { |
115 | target_num_threads_ = target_num_threads; |
116 | // As the device references the thread pool wrapper, destroy it first. |
117 | device_.reset(); |
118 | thread_pool_wrapper_.reset(); |
119 | } |
120 | } |
121 | |
122 | private: |
123 | int target_num_threads_ = kDefaultNumThreadpoolThreads; |
124 | // Both device_ and thread_pool_wrapper_ are lazily created. |
125 | std::unique_ptr<Eigen::ThreadPoolDevice> device_; |
126 | std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper_; |
127 | }; |
128 | |
// External-context payload stored on the TfLiteContext: the lazily-created
// Eigen thread pool holder plus a count of the kernels currently using it.
struct RefCountedEigenContext : public TfLiteExternalContext {
  // Owns the (lazily built) Eigen thread pool/device shared by all kernels.
  std::unique_ptr<LazyEigenThreadPoolHolder> thread_pool_holder;
  // Number of IncrementUsageCounter() calls not yet matched by a decrement.
  int num_references = 0;
};
133 | |
134 | RefCountedEigenContext* GetEigenContext(TfLiteContext* context) { |
135 | return reinterpret_cast<RefCountedEigenContext*>( |
136 | context->GetExternalContext(context, kTfLiteEigenContext)); |
137 | } |
138 | |
139 | TfLiteStatus Refresh(TfLiteContext* context) { |
140 | if (IsValidNumThreads(context->recommended_num_threads)) { |
141 | SetEigenNbThreads(GetNumThreads(context->recommended_num_threads)); |
142 | } |
143 | |
144 | auto* ptr = GetEigenContext(context); |
145 | if (ptr != nullptr) { |
146 | ptr->thread_pool_holder->SetNumThreads(context->recommended_num_threads); |
147 | } |
148 | |
149 | return kTfLiteOk; |
150 | } |
151 | |
152 | } // namespace |
153 | |
154 | void IncrementUsageCounter(TfLiteContext* context) { |
155 | auto* ptr = GetEigenContext(context); |
156 | if (ptr == nullptr) { |
157 | if (IsValidNumThreads(context->recommended_num_threads)) { |
158 | SetEigenNbThreads(context->recommended_num_threads); |
159 | } |
160 | ptr = new RefCountedEigenContext; |
161 | ptr->type = kTfLiteEigenContext; |
162 | ptr->Refresh = Refresh; |
163 | ptr->thread_pool_holder = std::make_unique<LazyEigenThreadPoolHolder>( |
164 | context->recommended_num_threads); |
165 | ptr->num_references = 0; |
166 | context->SetExternalContext(context, kTfLiteEigenContext, ptr); |
167 | } |
168 | ptr->num_references++; |
169 | } |
170 | |
171 | void DecrementUsageCounter(TfLiteContext* context) { |
172 | auto* ptr = GetEigenContext(context); |
173 | if (ptr == nullptr) { |
174 | TF_LITE_FATAL( |
175 | "Call to DecrementUsageCounter() not preceded by " |
176 | "IncrementUsageCounter()" ); |
177 | } |
178 | if (--ptr->num_references == 0) { |
179 | delete ptr; |
180 | context->SetExternalContext(context, kTfLiteEigenContext, nullptr); |
181 | } |
182 | } |
183 | |
184 | const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) { |
185 | auto* ptr = GetEigenContext(context); |
186 | if (ptr == nullptr) { |
187 | TF_LITE_FATAL( |
188 | "Call to GetFromContext() not preceded by IncrementUsageCounter()" ); |
189 | } |
190 | return ptr->thread_pool_holder->GetThreadPoolDevice(); |
191 | } |
192 | |
193 | } // namespace eigen_support |
194 | } // namespace tflite |
195 | |