1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | /// \file |
16 | /// Deserialization infrastructure for tflite. Provides functionality |
17 | /// to go from a serialized tflite model in flatbuffer format to an |
18 | /// in-memory representation of the model. |
19 | /// |
20 | #ifndef TENSORFLOW_LITE_MODEL_BUILDER_H_ |
21 | #define TENSORFLOW_LITE_MODEL_BUILDER_H_ |
22 | |
23 | #include <stddef.h> |
24 | |
25 | #include <map> |
26 | #include <memory> |
27 | #include <string> |
28 | |
29 | #include "tensorflow/lite/allocation.h" |
30 | #include "tensorflow/lite/c/common.h" |
31 | #include "tensorflow/lite/core/api/error_reporter.h" |
32 | #include "tensorflow/lite/core/api/op_resolver.h" |
33 | #include "tensorflow/lite/core/api/verifier.h" |
34 | #include "tensorflow/lite/mutable_op_resolver.h" |
35 | #include "tensorflow/lite/schema/schema_generated.h" |
36 | #include "tensorflow/lite/stderr_reporter.h" |
37 | #include "tensorflow/lite/string_type.h" |
38 | |
39 | namespace tflite { |
40 | |
41 | /// An RAII object that represents a read-only tflite model, copied from disk, |
42 | /// or mmapped. This uses flatbuffers as the serialization format. |
43 | /// |
44 | /// NOTE: The current API requires that a FlatBufferModel instance be kept alive |
45 | /// by the client as long as it is in use by any dependent Interpreter |
46 | /// instances. As the FlatBufferModel instance is effectively immutable after |
47 | /// creation, the client may safely use a single model with multiple dependent |
48 | /// Interpreter instances, even across multiple threads (though note that each |
49 | /// Interpreter instance is *not* thread-safe). |
50 | /// |
51 | /// <pre><code> |
52 | /// using namespace tflite; |
53 | /// StderrReporter error_reporter; |
54 | /// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite", |
55 | /// &error_reporter); |
56 | /// MyOpResolver resolver; // You need to subclass OpResolver to provide |
57 | /// // implementations. |
58 | /// InterpreterBuilder builder(*model, resolver); |
59 | /// std::unique_ptr<Interpreter> interpreter; |
60 | /// if(builder(&interpreter) == kTfLiteOk) { |
61 | /// .. run model inference with interpreter |
62 | /// } |
63 | /// </code></pre> |
64 | /// |
65 | /// OpResolver must be defined to provide your kernel implementations to the |
66 | /// interpreter. This is environment specific and may consist of just the |
67 | /// builtin ops, or some custom operators you defined to extend tflite. |
68 | class FlatBufferModel { |
69 | public: |
70 | /// Builds a model based on a file. |
71 | /// Caller retains ownership of `error_reporter` and must ensure its lifetime |
72 | /// is longer than the FlatBufferModel instance. |
73 | /// Returns a nullptr in case of failure. |
74 | static std::unique_ptr<FlatBufferModel> BuildFromFile( |
75 | const char* filename, |
76 | ErrorReporter* error_reporter = DefaultErrorReporter()); |
77 | |
78 | /// Verifies whether the content of the file is legit, then builds a model |
79 | /// based on the file. |
80 | /// The extra_verifier argument is an additional optional verifier for the |
81 | /// file contents. By default, we always check with tflite::VerifyModelBuffer. |
82 | /// If extra_verifier is supplied, the file contents is also checked against |
83 | /// the extra_verifier after the check against tflite::VerifyModelBuilder. |
84 | /// Caller retains ownership of `error_reporter` and must ensure its lifetime |
85 | /// is longer than the FlatBufferModel instance. |
86 | /// Returns a nullptr in case of failure. |
87 | static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile( |
88 | const char* filename, TfLiteVerifier* = nullptr, |
89 | ErrorReporter* error_reporter = DefaultErrorReporter()); |
90 | |
91 | /// Builds a model based on a pre-loaded flatbuffer. |
92 | /// Caller retains ownership of the buffer and should keep it alive until |
93 | /// the returned object is destroyed. Caller also retains ownership of |
94 | /// `error_reporter` and must ensure its lifetime is longer than the |
95 | /// FlatBufferModel instance. |
96 | /// Returns a nullptr in case of failure. |
97 | /// NOTE: this does NOT validate the buffer so it should NOT be called on |
98 | /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case |
99 | static std::unique_ptr<FlatBufferModel> BuildFromBuffer( |
100 | const char* caller_owned_buffer, size_t buffer_size, |
101 | ErrorReporter* error_reporter = DefaultErrorReporter()); |
102 | |
103 | /// Verifies whether the content of the buffer is legit, then builds a model |
104 | /// based on the pre-loaded flatbuffer. |
105 | /// The extra_verifier argument is an additional optional verifier for the |
106 | /// buffer. By default, we always check with tflite::VerifyModelBuffer. If |
107 | /// extra_verifier is supplied, the buffer is checked against the |
108 | /// extra_verifier after the check against tflite::VerifyModelBuilder. The |
109 | /// caller retains ownership of the buffer and should keep it alive until the |
110 | /// returned object is destroyed. Caller retains ownership of `error_reporter` |
111 | /// and must ensure its lifetime is longer than the FlatBufferModel instance. |
112 | /// Returns a nullptr in case of failure. |
113 | static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer( |
114 | const char* caller_owned_buffer, size_t buffer_size, |
115 | TfLiteVerifier* = nullptr, |
116 | ErrorReporter* error_reporter = DefaultErrorReporter()); |
117 | |
118 | /// Builds a model directly from an allocation. |
119 | /// Ownership of the allocation is passed to the model, but the caller |
120 | /// retains ownership of `error_reporter` and must ensure its lifetime is |
121 | /// longer than the FlatBufferModel instance. |
122 | /// Returns a nullptr in case of failure (e.g., the allocation is invalid). |
123 | static std::unique_ptr<FlatBufferModel> BuildFromAllocation( |
124 | std::unique_ptr<Allocation> allocation, |
125 | ErrorReporter* error_reporter = DefaultErrorReporter()); |
126 | |
127 | /// Verifies whether the content of the allocation is legit, then builds a |
128 | /// model based on the provided allocation. |
129 | /// The extra_verifier argument is an additional optional verifier for the |
130 | /// buffer. By default, we always check with tflite::VerifyModelBuffer. If |
131 | /// extra_verifier is supplied, the buffer is checked against the |
132 | /// extra_verifier after the check against tflite::VerifyModelBuilder. |
133 | /// Ownership of the allocation is passed to the model, but the caller |
134 | /// retains ownership of `error_reporter` and must ensure its lifetime is |
135 | /// longer than the FlatBufferModel instance. |
136 | /// Returns a nullptr in case of failure. |
137 | static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromAllocation( |
138 | std::unique_ptr<Allocation> allocation, |
139 | TfLiteVerifier* = nullptr, |
140 | ErrorReporter* error_reporter = DefaultErrorReporter()); |
141 | |
142 | /// Builds a model directly from a flatbuffer pointer |
143 | /// Caller retains ownership of the buffer and should keep it alive until the |
144 | /// returned object is destroyed. Caller retains ownership of `error_reporter` |
145 | /// and must ensure its lifetime is longer than the FlatBufferModel instance. |
146 | /// Returns a nullptr in case of failure. |
147 | static std::unique_ptr<FlatBufferModel> BuildFromModel( |
148 | const tflite::Model* caller_owned_model_spec, |
149 | ErrorReporter* error_reporter = DefaultErrorReporter()); |
150 | |
151 | // Releases memory or unmaps mmaped memory. |
152 | ~FlatBufferModel(); |
153 | |
154 | // Copying or assignment is disallowed to simplify ownership semantics. |
155 | FlatBufferModel(const FlatBufferModel&) = delete; |
156 | FlatBufferModel& operator=(const FlatBufferModel&) = delete; |
157 | |
158 | bool initialized() const { return model_ != nullptr; } |
159 | const tflite::Model* operator->() const { return model_; } |
160 | const tflite::Model* GetModel() const { return model_; } |
161 | ErrorReporter* error_reporter() const { return error_reporter_; } |
162 | const Allocation* allocation() const { return allocation_.get(); } |
163 | |
164 | // Returns the minimum runtime version from the flatbuffer. This runtime |
165 | // version encodes the minimum required interpreter version to run the |
166 | // flatbuffer model. If the minimum version can't be determined, an empty |
167 | // string will be returned. |
168 | // Note that the returned minimum version is a lower-bound but not a strict |
169 | // lower-bound; ops in the graph may not have an associated runtime version, |
170 | // in which case the actual required runtime might be greater than the |
171 | // reported minimum. |
172 | std::string GetMinimumRuntime() const; |
173 | |
174 | // Return model metadata as a mapping of name & buffer strings. |
175 | // See Metadata table in TFLite schema. |
176 | std::map<std::string, std::string> ReadAllMetadata() const; |
177 | |
178 | /// Returns true if the model identifier is correct (otherwise false and |
179 | /// reports an error). |
180 | bool CheckModelIdentifier() const; |
181 | |
182 | private: |
183 | /// Loads a model from a given allocation. FlatBufferModel will take over the |
184 | /// ownership of `allocation`, and delete it in destructor. The ownership of |
185 | /// `error_reporter`remains with the caller and must have lifetime at least |
186 | /// as much as FlatBufferModel. This is to allow multiple models to use the |
187 | /// same ErrorReporter instance. |
188 | explicit FlatBufferModel( |
189 | std::unique_ptr<Allocation> allocation, |
190 | ErrorReporter* error_reporter = DefaultErrorReporter()); |
191 | |
192 | /// Loads a model from Model flatbuffer. The `model` has to remain alive and |
193 | /// unchanged until the end of this flatbuffermodel's lifetime. |
194 | FlatBufferModel(const Model* model, ErrorReporter* error_reporter); |
195 | |
196 | /// Flatbuffer traverser pointer. (Model* is a pointer that is within the |
197 | /// allocated memory of the data allocated by allocation's internals. |
198 | const tflite::Model* model_ = nullptr; |
199 | /// The error reporter to use for model errors and subsequent errors when |
200 | /// the interpreter is created |
201 | ErrorReporter* error_reporter_; |
202 | /// The allocator used for holding memory of the model. Note that this will |
203 | /// be null if the client provides a tflite::Model directly. |
204 | std::unique_ptr<Allocation> allocation_; |
205 | }; |
206 | |
207 | } // namespace tflite |
208 | |
209 | #endif // TENSORFLOW_LITE_MODEL_BUILDER_H_ |
210 | |