1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/model_builder.h" |
16 | |
17 | #include <stddef.h> |
18 | #include <stdint.h> |
19 | |
20 | #include <memory> |
21 | #include <string> |
22 | #include <utility> |
23 | |
24 | #include "flatbuffers/flatbuffers.h" // from @flatbuffers |
25 | #include "tensorflow/lite/allocation.h" |
26 | #include "tensorflow/lite/core/api/error_reporter.h" |
27 | #include "tensorflow/lite/core/api/verifier.h" |
28 | #include "tensorflow/lite/schema/schema_generated.h" |
29 | #include "tensorflow/lite/stderr_reporter.h" |
30 | #include "tensorflow/lite/string_type.h" |
31 | |
32 | namespace tflite { |
33 | |
34 | namespace { |
35 | |
36 | // Ensure that ErrorReporter is non-null. |
37 | ErrorReporter* ValidateErrorReporter(ErrorReporter* e) { |
38 | return e ? e : DefaultErrorReporter(); |
39 | } |
40 | |
41 | } // namespace |
42 | |
43 | #ifndef TFLITE_MCU |
44 | // Loads a model from `filename`. If `mmap_file` is true then use mmap, |
45 | // otherwise make a copy of the model in a buffer. |
46 | std::unique_ptr<Allocation> GetAllocationFromFile( |
47 | const char* filename, ErrorReporter* error_reporter) { |
48 | std::unique_ptr<Allocation> allocation; |
49 | if (MMAPAllocation::IsSupported()) { |
50 | allocation = std::make_unique<MMAPAllocation>(filename, error_reporter); |
51 | } else { |
52 | allocation = std::make_unique<FileCopyAllocation>(filename, error_reporter); |
53 | } |
54 | return allocation; |
55 | } |
56 | |
57 | std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile( |
58 | const char* filename, ErrorReporter* error_reporter) { |
59 | error_reporter = ValidateErrorReporter(error_reporter); |
60 | return BuildFromAllocation(GetAllocationFromFile(filename, error_reporter), |
61 | error_reporter); |
62 | } |
63 | |
64 | std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile( |
65 | const char* filename, TfLiteVerifier* , |
66 | ErrorReporter* error_reporter) { |
67 | error_reporter = ValidateErrorReporter(error_reporter); |
68 | return VerifyAndBuildFromAllocation( |
69 | GetAllocationFromFile(filename, error_reporter), extra_verifier, |
70 | error_reporter); |
71 | } |
72 | #endif |
73 | |
74 | std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer( |
75 | const char* caller_owned_buffer, size_t buffer_size, |
76 | ErrorReporter* error_reporter) { |
77 | error_reporter = ValidateErrorReporter(error_reporter); |
78 | std::unique_ptr<Allocation> allocation( |
79 | new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter)); |
80 | return BuildFromAllocation(std::move(allocation), error_reporter); |
81 | } |
82 | |
83 | std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer( |
84 | const char* caller_owned_buffer, size_t buffer_size, |
85 | TfLiteVerifier* , ErrorReporter* error_reporter) { |
86 | error_reporter = ValidateErrorReporter(error_reporter); |
87 | std::unique_ptr<Allocation> allocation( |
88 | new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter)); |
89 | return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier, |
90 | error_reporter); |
91 | } |
92 | |
93 | std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation( |
94 | std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) { |
95 | std::unique_ptr<FlatBufferModel> model(new FlatBufferModel( |
96 | std::move(allocation), ValidateErrorReporter(error_reporter))); |
97 | if (!model->initialized()) { |
98 | model.reset(); |
99 | } |
100 | return model; |
101 | } |
102 | |
103 | std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation( |
104 | std::unique_ptr<Allocation> allocation, TfLiteVerifier* , |
105 | ErrorReporter* error_reporter) { |
106 | error_reporter = ValidateErrorReporter(error_reporter); |
107 | if (!allocation || !allocation->valid()) { |
108 | TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty" ); |
109 | return nullptr; |
110 | } |
111 | |
112 | flatbuffers::Verifier base_verifier( |
113 | reinterpret_cast<const uint8_t*>(allocation->base()), |
114 | allocation->bytes()); |
115 | if (!VerifyModelBuffer(base_verifier)) { |
116 | TF_LITE_REPORT_ERROR(error_reporter, |
117 | "The model is not a valid Flatbuffer buffer" ); |
118 | return nullptr; |
119 | } |
120 | |
121 | if (extra_verifier && |
122 | !extra_verifier->Verify(static_cast<const char*>(allocation->base()), |
123 | allocation->bytes(), error_reporter)) { |
124 | // The verifier will have already logged an appropriate error message. |
125 | return nullptr; |
126 | } |
127 | |
128 | return BuildFromAllocation(std::move(allocation), error_reporter); |
129 | } |
130 | |
131 | std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel( |
132 | const tflite::Model* caller_owned_model_spec, |
133 | ErrorReporter* error_reporter) { |
134 | error_reporter = ValidateErrorReporter(error_reporter); |
135 | |
136 | std::unique_ptr<FlatBufferModel> model( |
137 | new FlatBufferModel(caller_owned_model_spec, error_reporter)); |
138 | if (!model->initialized()) { |
139 | model.reset(); |
140 | } |
141 | return model; |
142 | } |
143 | |
144 | string FlatBufferModel::GetMinimumRuntime() const { |
145 | if (!model_ || !model_->metadata()) return "" ; |
146 | |
147 | for (int i = 0; i < model_->metadata()->size(); ++i) { |
148 | auto metadata = model_->metadata()->Get(i); |
149 | if (metadata->name()->str() == "min_runtime_version" ) { |
150 | auto buf = metadata->buffer(); |
151 | auto* buffer = (*model_->buffers())[buf]; |
152 | auto* array = buffer->data(); |
153 | // Get the real length of the runtime string, since there might be |
154 | // trailing |
155 | // '\0's in the buffer. |
156 | for (int len = 0; len < array->size(); ++len) { |
157 | if (array->data()[len] == '\0') { |
158 | return string(reinterpret_cast<const char*>(array->data()), len); |
159 | } |
160 | } |
161 | // If there is no '\0' in the buffer, this indicates that the flatbuffer |
162 | // is malformed. |
163 | TF_LITE_REPORT_ERROR( |
164 | error_reporter_, |
165 | "Min_runtime_version in model metadata is malformed" ); |
166 | break; |
167 | } |
168 | } |
169 | return "" ; |
170 | } |
171 | |
172 | std::map<std::string, std::string> FlatBufferModel::ReadAllMetadata() const { |
173 | std::map<std::string, std::string> keys_values; |
174 | if (!model_ || !model_->metadata() || !model_->buffers()) return keys_values; |
175 | |
176 | for (int i = 0; i < model_->metadata()->size(); ++i) { |
177 | auto metadata = model_->metadata()->Get(i); |
178 | auto buf = metadata->buffer(); |
179 | if (buf >= model_->buffers()->size()) continue; |
180 | const tflite::Buffer* buffer = (*model_->buffers())[buf]; |
181 | if (!buffer || !buffer->data()) continue; |
182 | const flatbuffers::Vector<uint8_t>* array = buffer->data(); |
183 | if (!array) continue; |
184 | std::string val = |
185 | string(reinterpret_cast<const char*>(array->data()), array->size()); |
186 | // Skip if key or value of metadata is empty. |
187 | if (!metadata->name() || val.empty()) continue; |
188 | keys_values[metadata->name()->str()] = val; |
189 | } |
190 | return keys_values; |
191 | } |
192 | |
193 | bool FlatBufferModel::CheckModelIdentifier() const { |
194 | if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { |
195 | const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); |
196 | error_reporter_->Report( |
197 | "Model provided has model identifier '%c%c%c%c', should be '%s'\n" , |
198 | ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); |
199 | return false; |
200 | } |
201 | return true; |
202 | } |
203 | |
204 | FlatBufferModel::FlatBufferModel(const Model* model, |
205 | ErrorReporter* error_reporter) |
206 | : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {} |
207 | |
208 | FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation, |
209 | ErrorReporter* error_reporter) |
210 | : error_reporter_(ValidateErrorReporter(error_reporter)), |
211 | allocation_(std::move(allocation)) { |
212 | if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) { |
213 | return; |
214 | } |
215 | |
216 | model_ = ::tflite::GetModel(allocation_->base()); |
217 | } |
218 | |
219 | FlatBufferModel::~FlatBufferModel() {} |
220 | |
221 | } // namespace tflite |
222 | |