1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/model_builder.h"
16
17#include <stddef.h>
18#include <stdint.h>
19
20#include <memory>
21#include <string>
22#include <utility>
23
24#include "flatbuffers/flatbuffers.h" // from @flatbuffers
25#include "tensorflow/lite/allocation.h"
26#include "tensorflow/lite/core/api/error_reporter.h"
27#include "tensorflow/lite/core/api/verifier.h"
28#include "tensorflow/lite/schema/schema_generated.h"
29#include "tensorflow/lite/stderr_reporter.h"
30#include "tensorflow/lite/string_type.h"
31
32namespace tflite {
33
34namespace {
35
36// Ensure that ErrorReporter is non-null.
37ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
38 return e ? e : DefaultErrorReporter();
39}
40
41} // namespace
42
43#ifndef TFLITE_MCU
44// Loads a model from `filename`. If `mmap_file` is true then use mmap,
45// otherwise make a copy of the model in a buffer.
46std::unique_ptr<Allocation> GetAllocationFromFile(
47 const char* filename, ErrorReporter* error_reporter) {
48 std::unique_ptr<Allocation> allocation;
49 if (MMAPAllocation::IsSupported()) {
50 allocation = std::make_unique<MMAPAllocation>(filename, error_reporter);
51 } else {
52 allocation = std::make_unique<FileCopyAllocation>(filename, error_reporter);
53 }
54 return allocation;
55}
56
57std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
58 const char* filename, ErrorReporter* error_reporter) {
59 error_reporter = ValidateErrorReporter(error_reporter);
60 return BuildFromAllocation(GetAllocationFromFile(filename, error_reporter),
61 error_reporter);
62}
63
64std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
65 const char* filename, TfLiteVerifier* extra_verifier,
66 ErrorReporter* error_reporter) {
67 error_reporter = ValidateErrorReporter(error_reporter);
68 return VerifyAndBuildFromAllocation(
69 GetAllocationFromFile(filename, error_reporter), extra_verifier,
70 error_reporter);
71}
72#endif
73
74std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
75 const char* caller_owned_buffer, size_t buffer_size,
76 ErrorReporter* error_reporter) {
77 error_reporter = ValidateErrorReporter(error_reporter);
78 std::unique_ptr<Allocation> allocation(
79 new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
80 return BuildFromAllocation(std::move(allocation), error_reporter);
81}
82
83std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
84 const char* caller_owned_buffer, size_t buffer_size,
85 TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
86 error_reporter = ValidateErrorReporter(error_reporter);
87 std::unique_ptr<Allocation> allocation(
88 new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
89 return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
90 error_reporter);
91}
92
93std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
94 std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
95 std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
96 std::move(allocation), ValidateErrorReporter(error_reporter)));
97 if (!model->initialized()) {
98 model.reset();
99 }
100 return model;
101}
102
103std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
104 std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
105 ErrorReporter* error_reporter) {
106 error_reporter = ValidateErrorReporter(error_reporter);
107 if (!allocation || !allocation->valid()) {
108 TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
109 return nullptr;
110 }
111
112 flatbuffers::Verifier base_verifier(
113 reinterpret_cast<const uint8_t*>(allocation->base()),
114 allocation->bytes());
115 if (!VerifyModelBuffer(base_verifier)) {
116 TF_LITE_REPORT_ERROR(error_reporter,
117 "The model is not a valid Flatbuffer buffer");
118 return nullptr;
119 }
120
121 if (extra_verifier &&
122 !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
123 allocation->bytes(), error_reporter)) {
124 // The verifier will have already logged an appropriate error message.
125 return nullptr;
126 }
127
128 return BuildFromAllocation(std::move(allocation), error_reporter);
129}
130
131std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
132 const tflite::Model* caller_owned_model_spec,
133 ErrorReporter* error_reporter) {
134 error_reporter = ValidateErrorReporter(error_reporter);
135
136 std::unique_ptr<FlatBufferModel> model(
137 new FlatBufferModel(caller_owned_model_spec, error_reporter));
138 if (!model->initialized()) {
139 model.reset();
140 }
141 return model;
142}
143
144string FlatBufferModel::GetMinimumRuntime() const {
145 if (!model_ || !model_->metadata()) return "";
146
147 for (int i = 0; i < model_->metadata()->size(); ++i) {
148 auto metadata = model_->metadata()->Get(i);
149 if (metadata->name()->str() == "min_runtime_version") {
150 auto buf = metadata->buffer();
151 auto* buffer = (*model_->buffers())[buf];
152 auto* array = buffer->data();
153 // Get the real length of the runtime string, since there might be
154 // trailing
155 // '\0's in the buffer.
156 for (int len = 0; len < array->size(); ++len) {
157 if (array->data()[len] == '\0') {
158 return string(reinterpret_cast<const char*>(array->data()), len);
159 }
160 }
161 // If there is no '\0' in the buffer, this indicates that the flatbuffer
162 // is malformed.
163 TF_LITE_REPORT_ERROR(
164 error_reporter_,
165 "Min_runtime_version in model metadata is malformed");
166 break;
167 }
168 }
169 return "";
170}
171
172std::map<std::string, std::string> FlatBufferModel::ReadAllMetadata() const {
173 std::map<std::string, std::string> keys_values;
174 if (!model_ || !model_->metadata() || !model_->buffers()) return keys_values;
175
176 for (int i = 0; i < model_->metadata()->size(); ++i) {
177 auto metadata = model_->metadata()->Get(i);
178 auto buf = metadata->buffer();
179 if (buf >= model_->buffers()->size()) continue;
180 const tflite::Buffer* buffer = (*model_->buffers())[buf];
181 if (!buffer || !buffer->data()) continue;
182 const flatbuffers::Vector<uint8_t>* array = buffer->data();
183 if (!array) continue;
184 std::string val =
185 string(reinterpret_cast<const char*>(array->data()), array->size());
186 // Skip if key or value of metadata is empty.
187 if (!metadata->name() || val.empty()) continue;
188 keys_values[metadata->name()->str()] = val;
189 }
190 return keys_values;
191}
192
193bool FlatBufferModel::CheckModelIdentifier() const {
194 if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
195 const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
196 error_reporter_->Report(
197 "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
198 ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
199 return false;
200 }
201 return true;
202}
203
204FlatBufferModel::FlatBufferModel(const Model* model,
205 ErrorReporter* error_reporter)
206 : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
207
208FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
209 ErrorReporter* error_reporter)
210 : error_reporter_(ValidateErrorReporter(error_reporter)),
211 allocation_(std::move(allocation)) {
212 if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
213 return;
214 }
215
216 model_ = ::tflite::GetModel(allocation_->base());
217}
218
219FlatBufferModel::~FlatBufferModel() {}
220
221} // namespace tflite
222