1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | /// \file |
16 | /// Provides functionality to construct an interpreter for a model. |
17 | /// |
18 | #ifndef TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ |
19 | #define TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ |
20 | |
21 | #include <map> |
22 | #include <memory> |
23 | #include <string> |
24 | #include <vector> |
25 | |
26 | #include "flatbuffers/flatbuffers.h" // from @flatbuffers |
27 | #include "tensorflow/lite/allocation.h" |
28 | #include "tensorflow/lite/c/common.h" |
29 | #include "tensorflow/lite/core/api/error_reporter.h" |
30 | #include "tensorflow/lite/core/api/op_resolver.h" |
31 | #include "tensorflow/lite/core/subgraph.h" |
32 | #include "tensorflow/lite/interpreter.h" |
33 | #include "tensorflow/lite/model_builder.h" |
34 | #include "tensorflow/lite/mutable_op_resolver.h" |
35 | #include "tensorflow/lite/schema/schema_generated.h" |
36 | #include "tensorflow/lite/stderr_reporter.h" |
37 | |
38 | namespace tflite { |
39 | |
40 | /// Build an interpreter capable of interpreting `model`. |
41 | /// |
42 | /// `model`: A model whose lifetime must be at least as long as any |
43 | /// interpreter(s) created by the builder. In principle multiple interpreters |
44 | /// can be made from a single model. |
45 | /// `op_resolver`: An instance that implements the `OpResolver` interface, which |
46 | /// maps custom op names and builtin op codes to op registrations. The |
47 | /// lifetime of the provided `op_resolver` object must be at least as long as |
48 | /// the `InterpreterBuilder`; unlike `model` and `error_reporter`, the |
49 | /// `op_resolver` does not need to exist for the duration of any created |
50 | /// `Interpreter` objects. |
51 | /// `error_reporter`: a functor that is called to report errors that handles |
52 | /// printf var arg semantics. The lifetime of the `error_reporter` object must |
53 | /// be greater than or equal to the `Interpreter` created by `operator()`. |
54 | /// `options_experimental`: Options that can change behavior of interpreter. |
55 | /// WARNING: this parameter is an experimental API and is subject to change. |
56 | /// |
57 | /// Returns a kTfLiteOk when successful and sets interpreter to a valid |
58 | /// Interpreter. Note: The user must ensure the lifetime of the model (and error |
59 | /// reporter, if provided) is at least as long as interpreter's lifetime, and |
60 | /// a single model instance may safely be used with multiple interpreters. |
61 | class InterpreterBuilder { |
62 | public: |
63 | /// For this constructor, the ErrorReporter will be extracted from the |
64 | /// FlatBufferModel. |
65 | /// `options` object is copied during construction. So caller can release it |
66 | // after calling the constructor. |
67 | InterpreterBuilder(const FlatBufferModel& model, |
68 | const OpResolver& op_resolver, |
69 | const InterpreterOptions* options_experimental = nullptr); |
70 | /// Builds an interpreter given only the raw flatbuffer Model object (instead |
71 | /// of a FlatBufferModel). Mostly used for testing. |
72 | /// If `error_reporter` is null, then DefaultErrorReporter() is used. |
73 | /// `options` object is copied during construction. So caller can release it |
74 | // after calling the constructor. |
75 | InterpreterBuilder(const ::tflite::Model* model, |
76 | const OpResolver& op_resolver, |
77 | ErrorReporter* error_reporter = DefaultErrorReporter(), |
78 | const InterpreterOptions* options_experimental = nullptr); |
79 | ~InterpreterBuilder(); |
80 | InterpreterBuilder(const InterpreterBuilder&) = delete; |
81 | InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; |
82 | |
83 | /// Builds an interpreter and stores it in `*interpreter`. |
84 | /// On success, returns kTfLiteOk and sets `*interpreter` to a valid |
85 | /// Interpreter. |
86 | /// On failure, returns an error status and sets `*interpreter` to nullptr. |
87 | TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter); |
88 | |
89 | /// Same as above, but also sets the number of CPU threads to use |
90 | /// (overriding any previous call to SetNumThreads). |
91 | /// Deprecated: use the SetNumThreads method instead. |
92 | TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter, |
93 | int num_threads); |
94 | |
95 | /// Sets the number of CPU threads to use for the interpreter. |
96 | /// Returns kTfLiteOk on success, kTfLiteError on error. |
97 | TfLiteStatus SetNumThreads(int num_threads); |
98 | |
99 | /// Any delegates added with AddDelegate will be applied to the Interpreter |
100 | /// generated by operator(), in the order that they were added. (The delegate |
101 | /// parameter passed to AddDelegate should be non-null, otherwise an error |
102 | /// will be reported, and the call to AddDelegate will have no other effect.) |
103 | /// The lifetime of the delegate must be at least as long as the lifetime of |
104 | /// any Interpreter generated by this InterpreterBuilder. |
105 | void AddDelegate(TfLiteDelegate* delegate); |
106 | |
107 | private: |
108 | TfLiteStatus BuildLocalIndexToRegistrationMapping(); |
109 | TfLiteStatus ParseNodes( |
110 | const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators, |
111 | Subgraph* subgraph); |
112 | TfLiteStatus ParseTensors( |
113 | const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, |
114 | const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors, |
115 | Subgraph* subgraph); |
116 | TfLiteStatus ApplyDelegates(Interpreter* interpreter); |
117 | TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization, |
118 | TfLiteQuantization* quantization, |
119 | const std::vector<int>& dims); |
120 | TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity, |
121 | TfLiteSparsity** sparsity); |
122 | TfLiteStatus ParseSignatureDefs( |
123 | const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>* |
124 | signature_def_list, |
125 | Interpreter* interpreter); |
126 | |
127 | const ::tflite::Model* model_; |
128 | const OpResolver& op_resolver_; |
129 | ErrorReporter* error_reporter_; |
130 | std::vector<TfLiteDelegate*> delegates_; |
131 | // Model metadata stored as mapping of name (key) to buffer (value). |
132 | // Data is mapped from the Metadata in TFLite flatbuffer model. |
133 | // TODO(b/188185962): Consider mapping to std::pair<const char*, size_t> if |
134 | // this increases runtime memory usage for large metadata. |
135 | std::map<std::string, std::string> metadata_; |
136 | |
137 | std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_; |
138 | std::vector<TfLiteRegistration> unresolved_custom_ops_; |
139 | std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_; |
140 | const Allocation* allocation_ = nullptr; |
141 | |
142 | bool has_flex_op_ = false; |
143 | int num_fp32_tensors_ = 0; |
144 | int num_threads_ = -1; |
145 | InterpreterOptions options_; |
146 | }; |
147 | |
148 | } // namespace tflite |
149 | |
150 | #endif // TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ |
151 | |