1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
---|---|
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
17 | |
18 | #include <memory> |
19 | |
20 | #ifdef TFLITE_HAVE_CPUINFO |
21 | #include "include/cpuinfo.h" |
22 | #endif |
23 | |
24 | #include "public/gemmlowp.h" |
25 | #include "ruy/context.h" // from @ruy |
26 | #include "ruy/path.h" // from @ruy |
27 | #include "tensorflow/lite/c/common.h" |
28 | #include "tensorflow/lite/core/macros.h" |
29 | #include "tensorflow/lite/external_cpu_backend_context.h" |
30 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
31 | #include "tensorflow/lite/kernels/op_macros.h" |
32 | |
namespace {

// Thread count applied by the CpuBackendContext constructor, and used by
// SetMaxNumThreads() whenever the caller passes a negative value.
constexpr int kDefaultNumThreadpoolThreads = 1;

}  // namespace
37 | |
38 | namespace tflite { |
39 | |
// Use weak symbols if possible to dispatch to deprecated paths.
// UseGemmlowpOnX86() is declared weak, so the symbol may be left undefined;
// PreferGemmlowpOnX86() compares it against nullptr before calling it.
#if TFLITE_HAS_ATTRIBUTE_WEAK && !defined(__APPLE__)
extern TFLITE_ATTRIBUTE_WEAK bool UseGemmlowpOnX86();
#endif  // TFLITE_HAS_ATTRIBUTE_WEAK && !defined(__APPLE__)
44 | |
45 | // TODO(b/138922878) Enable when Ruy builds on Apple. |
46 | #if defined(TFLITE_HAVE_CPUINFO) && !defined(__APPLE__) |
47 | CpuBackendContext::CpuInfo::~CpuInfo() { |
48 | if (init_status_ == InitStatus::kInitialized) { |
49 | cpuinfo_deinitialize(); |
50 | } |
51 | } |
52 | |
53 | bool CpuBackendContext::CpuInfo::EnsureInitialized() { |
54 | if (init_status_ == InitStatus::kNotYetAttempted) { |
55 | init_status_ = Initialize(); |
56 | } |
57 | return init_status_ == InitStatus::kInitialized; |
58 | } |
59 | |
60 | CpuBackendContext::CpuInfo::InitStatus |
61 | CpuBackendContext::CpuInfo::Initialize() { |
62 | TFLITE_DCHECK_EQ(init_status_, InitStatus::kNotYetAttempted); |
63 | if (!cpuinfo_initialize()) { |
64 | return InitStatus::kFailed; |
65 | } |
66 | return InitStatus::kInitialized; |
67 | } |
68 | |
69 | bool CpuBackendContext::CpuInfo::Avx2Fma() { |
70 | return EnsureInitialized() && cpuinfo_has_x86_avx2() && |
71 | cpuinfo_has_x86_fma3(); |
72 | } |
73 | |
74 | bool CpuBackendContext::CpuInfo::Avx() { |
75 | return EnsureInitialized() && cpuinfo_has_x86_avx(); |
76 | } |
77 | |
78 | bool CpuBackendContext::CpuInfo::Avx512() { |
79 | return EnsureInitialized() && cpuinfo_has_x86_avx512f() && |
80 | cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512cd() && |
81 | cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512vl(); |
82 | } |
83 | #else |
84 | |
85 | CpuBackendContext::CpuInfo::~CpuInfo() {} |
86 | |
87 | bool CpuBackendContext::CpuInfo::EnsureInitialized() { |
88 | if (init_status_ == InitStatus::kNotYetAttempted) { |
89 | init_status_ = InitStatus::kInitialized; |
90 | } |
91 | TFLITE_DCHECK_EQ(init_status_, InitStatus::kInitialized); |
92 | return true; |
93 | } |
94 | |
95 | bool CpuBackendContext::CpuInfo::Avx2Fma() { return false; } |
96 | |
97 | bool CpuBackendContext::CpuInfo::Avx() { return false; } |
98 | |
99 | bool CpuBackendContext::CpuInfo::Avx512() { return false; } |
100 | #endif // TFLITE_HAVE_CPUINFO |
101 | |
102 | CpuBackendContext* CpuBackendContext::GetFromContext(TfLiteContext* context) { |
103 | auto* external_context = static_cast<ExternalCpuBackendContext*>( |
104 | context->GetExternalContext(context, kTfLiteCpuBackendContext)); |
105 | |
106 | if (external_context == nullptr) { |
107 | TF_LITE_FATAL( |
108 | "ExternalCpuBackendContext isn't properly initialized during TFLite " |
109 | "interpreter initialization."); |
110 | } |
111 | |
112 | auto* cpu_backend_context = static_cast<CpuBackendContext*>( |
113 | external_context->internal_backend_context()); |
114 | if (cpu_backend_context == nullptr) { |
115 | // We do the lazy initialization here for the TfLiteInternalBackendContext |
116 | // that's wrapped inside ExternalCpuBackendContext. |
117 | cpu_backend_context = new CpuBackendContext(); |
118 | cpu_backend_context->SetMaxNumThreads(context->recommended_num_threads); |
119 | external_context->set_internal_backend_context( |
120 | std::unique_ptr<TfLiteInternalBackendContext>(cpu_backend_context)); |
121 | } |
122 | |
123 | return cpu_backend_context; |
124 | } |
125 | |
126 | CpuBackendContext::CpuBackendContext() |
127 | : TfLiteInternalBackendContext(), |
128 | ruy_context_(new ruy::Context), |
129 | gemmlowp_context_(new gemmlowp::GemmContext) { |
130 | SetMaxNumThreads(kDefaultNumThreadpoolThreads); |
131 | // TODO(b/148289189) Remove when clients have transitioned to runtime flag. |
132 | #ifdef TFLITE_WITH_RUY_GEMV |
133 | SetUseCaching(true); |
134 | #else |
135 | SetUseCaching(false); |
136 | #endif |
137 | } |
138 | |
139 | CpuBackendContext::~CpuBackendContext() {} |
140 | |
141 | void CpuBackendContext::SetMaxNumThreads(int max_num_threads) { |
142 | const int target_num_threads = |
143 | max_num_threads > -1 ? max_num_threads : kDefaultNumThreadpoolThreads; |
144 | max_num_threads_ = target_num_threads; |
145 | ruy_context_->set_max_num_threads(target_num_threads); |
146 | gemmlowp_context_->set_max_num_threads(target_num_threads); |
147 | } |
148 | |
149 | void CpuBackendContext::SetUseCaching(bool flag) { use_caching_ = flag; } |
150 | |
151 | bool CpuBackendContext::PreferGemmlowpOnX86() { |
152 | bool use_gemmlowp_on_x86 = false; |
153 | #if defined(TFLITE_X86_PLATFORM) && TFLITE_HAS_ATTRIBUTE_WEAK && \ |
154 | !defined(__APPLE__) |
155 | if (::tflite::UseGemmlowpOnX86 != nullptr) { |
156 | use_gemmlowp_on_x86 = ::tflite::UseGemmlowpOnX86(); |
157 | } |
158 | #endif // TFLITE_X86_PLATFORM && TFLITE_HAS_ATTRIBUTE_WEAK && !(__APPLE__) |
159 | return use_gemmlowp_on_x86 || !RuyHasAvxOrAbove(); |
160 | } |
161 | |
162 | bool CpuBackendContext::RuyHasAvxOrAbove() { |
163 | // TODO(b/183178387): Use a proper query to detect AVX/optimized paths. |
164 | #if RUY_PLATFORM_X86_ENHANCEMENTS |
165 | return cpuinfo_.Avx() || cpuinfo_.Avx2Fma() || cpuinfo_.Avx512(); |
166 | #else |
167 | return false; |
168 | #endif |
169 | } |
170 | |
171 | } // namespace tflite |
172 |