/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// \file
/// Provides functionality to construct an interpreter for a model.
///
#ifndef TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
#define TENSORFLOW_LITE_INTERPRETER_BUILDER_H_

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/stderr_reporter.h"

namespace tflite {

/// Builds an interpreter capable of interpreting `model`.
///
/// `model`: A model whose lifetime must be at least as long as any
/// interpreter(s) created by the builder. In principle, multiple interpreters
/// can be built from a single model.
/// `op_resolver`: An instance that implements the `OpResolver` interface,
/// which maps custom op names and builtin op codes to op registrations. The
/// lifetime of the provided `op_resolver` object must be at least as long as
/// the `InterpreterBuilder`; unlike `model` and `error_reporter`, the
/// `op_resolver` does not need to exist for the duration of any created
/// `Interpreter` objects.
/// `error_reporter`: a functor that is called to report errors and that
/// handles printf-style varargs semantics. The lifetime of the
/// `error_reporter` object must be at least as long as that of the
/// `Interpreter` created by `operator()`.
/// `options_experimental`: Options that can change the behavior of the
/// interpreter.
/// WARNING: this parameter is an experimental API and is subject to change.
///
/// Returns kTfLiteOk when successful and sets `*interpreter` to a valid
/// Interpreter. Note: The user must ensure that the lifetime of the model
/// (and of the error reporter, if provided) is at least as long as the
/// interpreter's lifetime; a single model instance may safely be used with
/// multiple interpreters.
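///
/// Example usage (a minimal sketch, not taken from the original docs; the
/// model path "model.tflite" and the use of
/// tflite::ops::builtin::BuiltinOpResolver from
/// "tensorflow/lite/kernels/register.h" are illustrative assumptions):
///
///   // Load the model; it must outlive any interpreter built from it.
///   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
///   tflite::ops::builtin::BuiltinOpResolver resolver;
///   std::unique_ptr<tflite::Interpreter> interpreter;
///   if (model != nullptr &&
///       tflite::InterpreterBuilder(*model, resolver)(&interpreter) ==
///           kTfLiteOk) {
///     // `interpreter` is valid here; allocate tensors and run inference.
///   }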
class InterpreterBuilder {
 public:
  /// For this constructor, the ErrorReporter will be extracted from the
  /// FlatBufferModel.
  /// The `options_experimental` object is copied during construction, so the
  /// caller may release it after the constructor returns.
  InterpreterBuilder(const FlatBufferModel& model,
                     const OpResolver& op_resolver,
                     const InterpreterOptions* options_experimental = nullptr);
  /// Builds an interpreter given only the raw flatbuffer Model object (instead
  /// of a FlatBufferModel). Mostly used for testing.
  /// If `error_reporter` is null, then DefaultErrorReporter() is used.
  /// The `options_experimental` object is copied during construction, so the
  /// caller may release it after the constructor returns.
  InterpreterBuilder(const ::tflite::Model* model,
                     const OpResolver& op_resolver,
                     ErrorReporter* error_reporter = DefaultErrorReporter(),
                     const InterpreterOptions* options_experimental = nullptr);
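  // Example (a minimal sketch, not from the original docs; `buffer` is assumed
  // to point at a flatbuffer-serialized TFLite model kept alive by the caller,
  // and `resolver` is an existing OpResolver):
  //
  //   const ::tflite::Model* raw_model = ::tflite::GetModel(buffer);
  //   tflite::InterpreterBuilder builder(raw_model, resolver);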
  ~InterpreterBuilder();
  InterpreterBuilder(const InterpreterBuilder&) = delete;
  InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;

  /// Builds an interpreter and stores it in `*interpreter`.
  /// On success, returns kTfLiteOk and sets `*interpreter` to a valid
  /// Interpreter.
  /// On failure, returns an error status and sets `*interpreter` to nullptr.
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);

  /// Same as above, but also sets the number of CPU threads to use
  /// (overriding any previous call to SetNumThreads).
  /// Deprecated: use the SetNumThreads method instead.
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
                          int num_threads);

  /// Sets the number of CPU threads to use for the interpreter.
  /// Returns kTfLiteOk on success, kTfLiteError on error.
  TfLiteStatus SetNumThreads(int num_threads);

  /// Any delegates added with AddDelegate will be applied to the Interpreter
  /// generated by operator(), in the order that they were added. (The delegate
  /// parameter passed to AddDelegate should be non-null, otherwise an error
  /// will be reported, and the call to AddDelegate will have no other effect.)
  /// The lifetime of the delegate must be at least as long as the lifetime of
  /// any Interpreter generated by this InterpreterBuilder.
  void AddDelegate(TfLiteDelegate* delegate);
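
  // Example (a minimal sketch, not from the original docs; `model` is assumed
  // to be an existing FlatBufferModel, `resolver` an OpResolver, and
  // `delegate` a TfLiteDelegate* that outlives the built interpreter):
  //
  //   tflite::InterpreterBuilder builder(model, resolver);
  //   builder.SetNumThreads(2);        // Use two CPU threads.
  //   builder.AddDelegate(delegate);   // Delegates are applied in the order
  //                                    // they were added.
  //   std::unique_ptr<tflite::Interpreter> interpreter;
  //   TfLiteStatus status = builder(&interpreter);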

 private:
  TfLiteStatus BuildLocalIndexToRegistrationMapping();
  TfLiteStatus ParseNodes(
      const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
      Subgraph* subgraph);
  TfLiteStatus ParseTensors(
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
      const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
      Subgraph* subgraph);
  TfLiteStatus ApplyDelegates(Interpreter* interpreter);
  TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
                                 TfLiteQuantization* quantization,
                                 const std::vector<int>& dims);
  TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
                             TfLiteSparsity** sparsity);
  TfLiteStatus ParseSignatureDefs(
      const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
          signature_def_list,
      Interpreter* interpreter);

  const ::tflite::Model* model_;
  const OpResolver& op_resolver_;
  ErrorReporter* error_reporter_;
  std::vector<TfLiteDelegate*> delegates_;
  // Model metadata stored as a mapping from name (key) to buffer (value).
  // Data is mapped from the Metadata in the TFLite flatbuffer model.
  // TODO(b/188185962): Consider mapping to std::pair<const char*, size_t> if
  // this increases runtime memory usage for large metadata.
  std::map<std::string, std::string> metadata_;

  std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
  std::vector<TfLiteRegistration> unresolved_custom_ops_;
  std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
  const Allocation* allocation_ = nullptr;

  bool has_flex_op_ = false;
  int num_fp32_tensors_ = 0;
  int num_threads_ = -1;
  InterpreterOptions options_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_INTERPRETER_BUILDER_H_