1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
17#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
18
// TFLITE_X86_PLATFORM is defined (with no value) whenever the compiler
// predefines one of the x86 target macros checked below: the 64-bit ones
// (__x86_64__, _M_X64) or the 32-bit ones (__i386, _M_IX86).
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
    defined(_M_IX86)
#define TFLITE_X86_PLATFORM
#endif
23
24#include <memory>
25
26#include "public/gemmlowp.h"
27#include "ruy/context.h" // from @ruy
28#include "tensorflow/lite/c/common.h"
29#include "tensorflow/lite/external_cpu_backend_context.h"
30
31namespace tflite {
32
33class CpuBackendContext final : public TfLiteInternalBackendContext {
34 public:
35 static CpuBackendContext* GetFromContext(TfLiteContext* context);
36
37 CpuBackendContext();
38 ~CpuBackendContext() override;
39
40 ruy::Context* ruy_context() const { return ruy_context_.get(); }
41
42 gemmlowp::GemmContext* gemmlowp_context() const {
43 return gemmlowp_context_.get();
44 }
45
46 // Sets the maximum-number-of-threads-to-use parameter, only as a means of
47 // passing around this information.
48 void SetMaxNumThreads(int max_num_threads) override;
49
50 int max_num_threads() const { return max_num_threads_; }
51
52 void SetUseCaching(bool flag);
53
54 bool use_caching() const { return use_caching_; }
55
56 void ClearCaches() override { ruy_context_->ClearPrepackedCache(); }
57
58 // Gemmlowp on x86 is a deprecated path but some clients may still use
59 // this path based on link time dependencies.
60 bool PreferGemmlowpOnX86();
61
62 private:
63 bool RuyHasAvxOrAbove();
64
65 // Copy the wrapper class for cpuinfo from Ruy.
66 class CpuInfo final {
67 public:
68 CpuInfo() {}
69 ~CpuInfo();
70
71 // X86 features
72 bool Avx();
73 bool Avx2Fma();
74 bool Avx512();
75
76 private:
77 enum class InitStatus {
78 kNotYetAttempted,
79 kInitialized,
80 kFailed,
81 };
82
83 InitStatus init_status_ = InitStatus::kNotYetAttempted;
84
85 bool EnsureInitialized();
86 InitStatus Initialize();
87 CpuInfo(const CpuInfo&) = delete;
88 CpuInfo& operator=(const CpuInfo&) = delete;
89 };
90
91 // To enable a smooth transition from the current direct usage
92 // of the underlying gemmlowp context to going through abstractions
93 // (see :cpu_backend_gemm), for now a CpuBackendContext always
94 // stores both a gemmlowp context and a ruy context.
95 // TODO(b/131416458): Once call sites all go through abstractions,
96 // elide what can be elided based on TFLITE_WITH_RUY.
97 const std::unique_ptr<ruy::Context> ruy_context_;
98 const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
99 CpuInfo cpuinfo_;
100
101 // The maximum of threads used for parallelizing TfLite ops. However,
102 // cpu_backend_threadpool::Execute creates as many threads as it's
103 // asked to, regardless of this. Typically a call site would query
104 // cpu_backend_context->max_num_threads() and used that to determine
105 // the number of tasks to create and to give to
106 // cpu_backend_threadpool::Execute.
107 //
108 // This value also gets propagated to back-ends, where it plays the same
109 // information-only role.
110 int max_num_threads_;
111 // For matrix muliplications with constants parameters (i.e. weights), we can
112 // sometimes provide speedups by caching the "prepacked" data, for some
113 // additional memory cost. This flag permits the user to route all
114 // CpuBackendGem operations to a library that permits such an optimization
115 // (currently the Ruy library only).
116 bool use_caching_;
117
118 CpuBackendContext(const CpuBackendContext&) = delete;
119};
120
121} // namespace tflite
122
123#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
124