/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter_builder.h"

#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/internal/signature_def.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/profiling/platform_profiler.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h"

// aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
#if !defined(__ANDROID__) || __ANDROID_API__ >= 28
// Neither Apple nor Windows provide aligned_alloc.
#if !defined(__APPLE__) && !defined(_WIN32)
#define TFLITE_USE_STD_ALIGNED_ALLOC
#endif
#endif
#endif

// TODO(b/139446230): Move to portable platform header.
#if defined(__ANDROID__)
#define TFLITE_IS_MOBILE_PLATFORM
#endif  // defined(__ANDROID__)

#if defined(__APPLE__)
#include "TargetConditionals.h"
#if TARGET_IPHONE_SIMULATOR
#define TFLITE_IS_MOBILE_PLATFORM
#elif TARGET_OS_IPHONE
#define TFLITE_IS_MOBILE_PLATFORM
#endif
#endif  // defined(__APPLE__)

namespace tflite {

namespace {

// Ensure that ErrorReporter is non-null.
ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
  return e ? e : DefaultErrorReporter();
}

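// Copies a flatbuffer index vector (e.g. Int32Vector) into a newly allocated
// TfLiteIntArray, widening each element to int. On success the caller takes
// ownership of *arr.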
template <typename T>
TfLiteStatus Copy(const T* data_ptr, TfLiteIntArray** arr) {
  if (data_ptr->values() == nullptr) {
    return kTfLiteError;
  }

  int size = data_ptr->values()->size();
  *arr = TfLiteIntArrayCreate(size);
  for (int i = 0; i < size; i++) {
    (*arr)->data[i] = static_cast<int>(data_ptr->values()->Get(i));
  }
  return kTfLiteOk;
}

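// Converts the array_segments and array_indices vectors of a sparse
// DimensionMetadata entry into TfLiteIntArrays on tgt, handling each of the
// supported SparseIndexVector storage types.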
TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
                                    TfLiteDimensionMetadata* tgt) {
  if (src->array_segments() == nullptr || src->array_indices() == nullptr) {
    return kTfLiteError;
  }
  TfLiteStatus status = kTfLiteOk;
  switch (src->array_segments_type()) {
    case SparseIndexVector_Int32Vector:
      status = Copy(src->array_segments_as_Int32Vector(), &tgt->array_segments);
      break;
    case SparseIndexVector_Uint16Vector:
      status =
          Copy(src->array_segments_as_Uint16Vector(), &tgt->array_segments);
      break;
    case SparseIndexVector_Uint8Vector:
      status = Copy(src->array_segments_as_Uint8Vector(), &tgt->array_segments);
      break;
    default:
      status = kTfLiteError;
      break;
  }
  if (status != kTfLiteOk) return status;

  switch (src->array_indices_type()) {
    case SparseIndexVector_Int32Vector:
      return Copy(src->array_indices_as_Int32Vector(), &tgt->array_indices);
    case SparseIndexVector_Uint16Vector:
      return Copy(src->array_indices_as_Uint16Vector(), &tgt->array_indices);
    case SparseIndexVector_Uint8Vector:
      return Copy(src->array_indices_as_Uint8Vector(), &tgt->array_indices);
    default:
      break;
  }
  return kTfLiteError;
}

// Helper that builds a std::map (tensor name -> tensor index) from a
// flatbuffer vector of TensorMap entries.
std::map<std::string, uint32_t> GetMapFromTensorMap(
    const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>>*
        tensor_map) {
  if (!tensor_map) return {};
  std::map<std::string, uint32_t> result;
  for (const auto tensor : *tensor_map) {
    if (tensor != nullptr && tensor->name() != nullptr) {
      result[tensor->name()->c_str()] = tensor->tensor_index();
    }
  }
  return result;
}

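// Returns whether the lazy delegate providers from the OpResolver should be
// applied. When the XNNPACK delegate is built with quantized (QS8/QU8)
// support, delegation can be useful even for models with no fp32 tensors.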
inline bool ShouldCreateLazyDelegateProviders(int num_fp32_tensors) {
#if defined(XNNPACK_DELEGATE_ENABLE_QS8) || defined(XNNPACK_DELEGATE_ENABLE_QU8)
  return true;
#else
  return num_fp32_tensors > 0;
#endif
}

}  // namespace

constexpr const char* kEmptyTensorName = "";

// Using weak symbols to create a delegate allows automatic injection of the
// delegate simply by adding it as a dependency.
// For flex delegate, see also the strong override in
// lite/delegates/flex/delegate.cc.
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
  // TF_AcquireFlexDelegate isn't defined on Android, so the block below would
  // have no effect there; we therefore only enable it for non-Android
  // platforms. Also, on Android 4.4 (KitKat), dlsym() has a bug where looking
  // up an unknown name raises SIGFPE and crashes the process, so on that
  // platform we must *not* call SharedLibrary::GetSymbol unless the symbol is
  // sure to exist.
#if !defined(__ANDROID__)
  auto acquire_flex_delegate_func =
      reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
          SharedLibrary::GetSymbol("TF_AcquireFlexDelegate"));
  if (acquire_flex_delegate_func) {
    return acquire_flex_delegate_func();
  }
#endif

#if !defined(TFLITE_IS_MOBILE_PLATFORM)
  // Load TF_AcquireFlexDelegate() from _pywrap_tensorflow_internal.so if it is
  // available.
#if defined(_WIN32)
  const wchar_t* filename_pywrap_tensorflow_internal =
      L"_pywrap_tensorflow_internal.pyd";
#elif defined(__APPLE__)
  const char* filename_pywrap_tensorflow_internal =
      "python/_pywrap_tensorflow_internal.so";
#else
  const char* filename_pywrap_tensorflow_internal =
      "_pywrap_tensorflow_internal.so";
#endif
  void* lib_tf_internal =
      SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
#if defined(_WIN32)
  if (lib_tf_internal == nullptr) {
    lib_tf_internal = SharedLibrary::LoadLibrary(
        L"_pywrap_tensorflow_interpreter_wrapper.pyd");
  }
#endif
  if (lib_tf_internal) {
    acquire_flex_delegate_func =
        reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
            SharedLibrary::GetLibrarySymbol(lib_tf_internal,
                                            "TF_AcquireFlexDelegate"));
    if (acquire_flex_delegate_func) {
      return acquire_flex_delegate_func();
    }
  }
#endif  // !defined(TFLITE_IS_MOBILE_PLATFORM)

  return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}

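// Typical usage (a minimal sketch; assumes a valid "model.tflite" on disk and
// the builtin op resolver from tensorflow/lite/kernels/register.h):
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) !=
//       kTfLiteOk) {
//     // Handle the error; `interpreter` has been reset.
//   }
//   interpreter->AllocateTensors();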
InterpreterBuilder::InterpreterBuilder(
    const FlatBufferModel& model, const OpResolver& op_resolver,
    const InterpreterOptions* options_experimental)
    : model_(model.GetModel()),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(model.error_reporter())),
      metadata_(model.ReadAllMetadata()),
      allocation_(model.allocation()) {
  if (options_experimental) {
    options_ = *options_experimental;
  }
}

InterpreterBuilder::InterpreterBuilder(
    const ::tflite::Model* model, const OpResolver& op_resolver,
    ErrorReporter* error_reporter,
    const InterpreterOptions* options_experimental)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(error_reporter)) {
  if (options_experimental) {
    options_ = *options_experimental;
  }
}

InterpreterBuilder::~InterpreterBuilder() {}

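// Resolves every OperatorCode in the model to a TfLiteRegistration via the
// OpResolver, caching the result per opcode rather than per node. Custom ops
// that the resolver doesn't know get a placeholder registration, since a
// delegate may still claim them later.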
TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
  TfLiteStatus status = kTfLiteOk;
  // Reset state.
  flatbuffer_op_index_to_registration_.clear();
  unresolved_custom_ops_.clear();

  auto opcodes = model_->operator_codes();
  if (!opcodes) {
    return status;
  }
  int num_custom_ops = 0;
  for (const OperatorCode* opcode : *opcodes) {
    if (GetBuiltinCode(opcode) == BuiltinOperator_CUSTOM) {
      num_custom_ops++;
    }
  }
  unresolved_custom_ops_.reserve(num_custom_ops);
  for (const OperatorCode* opcode : *opcodes) {
    const TfLiteRegistration* registration = nullptr;
    status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
                                       &registration);
    if (status != kTfLiteOk) {
      if (GetBuiltinCode(opcode) != BuiltinOperator_CUSTOM) {
        return status;
      }
      // If it's an unresolved custom op, allow it for now. It might be
      // resolved by a delegate later.
      if (!opcode->custom_code()) {
        error_reporter_->Report(
            "Operator with CUSTOM builtin_code has no custom_code.\n");
        return status;
      }
      const auto* op_name = opcode->custom_code()->c_str();
      unresolved_custom_ops_.push_back(CreateUnresolvedCustomOp(op_name));
      registration = &unresolved_custom_ops_.back();
      has_flex_op_ |= IsFlexOp(op_name);
      status = kTfLiteOk;
    }
    flatbuffer_op_index_to_registration_.push_back(registration);
  }
  return status;
}

namespace {
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  // A null flat_array yields an empty vector: flatbuffers::Pack stores empty
  // vectors as nullptr, so this also covers tensors whose shape is null.
  if (flat_array == nullptr) {
    return {};
  }
  std::vector<int> ret(flat_array->size());
  for (int i = 0; i < flat_array->size(); i++) {
    ret[i] = flat_array->Get(i);
  }
  return ret;
}

// Used to determine how the op data parsing function creates its working
// space.
class MallocDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size, size_t alignment_hint) override {
#ifdef TFLITE_USE_STD_ALIGNED_ALLOC
    // Ensure that alignment is a power of two and a multiple of sizeof(void *)
    // and that size is an integral multiple of alignment.
    size_t used_alignment = std::max(alignment_hint, sizeof(void*));
    size_t used_size =
        ((size + used_alignment - 1) / used_alignment) * used_alignment;
    TFLITE_DCHECK(
        (used_alignment != 0) &&
        ((used_alignment & (used_alignment - 1)) == 0));  // is power-of-two
    return aligned_alloc(used_alignment, used_size);
#else
    return malloc(size);
#endif
  }
  void Deallocate(void* data) override { free(data); }
};

}  // namespace

TfLiteStatus InterpreterBuilder::ParseNodes(
    const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // Reserve node storage up front to reduce redundant reallocations.
  subgraph->ReserveNodes(operators->size());

  for (int i = 0; i < operators->size(); ++i) {
    const auto* op = operators->Get(i);
    int index = op->opcode_index();
    if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
      error_reporter_->Report("Missing registration for opcode_index %d\n",
                              index);
      status = kTfLiteError;
      continue;
    }

    const TfLiteRegistration* registration =
        flatbuffer_op_index_to_registration_[index];
    if (registration == nullptr) {
      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
      status = kTfLiteError;
      continue;
    }

    BuiltinOperator op_type =
        static_cast<BuiltinOperator>(registration->builtin_code);

    if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
      error_reporter_->Report(
          "Found builtin operator %s with custom options.\n",
          EnumNameBuiltinOperator(op_type));
    }

    if (op_type == BuiltinOperator_CUSTOM) {
      if (op->custom_options()) {
        subgraph->AddNodeWithParameters(
            FlatBufferIntArrayToVector(op->inputs()),
            FlatBufferIntArrayToVector(op->outputs()),
            FlatBufferIntArrayToVector(op->intermediates()),
            reinterpret_cast<const char*>(op->custom_options()->data()),
            op->custom_options()->size(), nullptr, registration);
      } else {
        subgraph->AddNodeWithParameters(
            FlatBufferIntArrayToVector(op->inputs()),
            FlatBufferIntArrayToVector(op->outputs()),
            FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
            nullptr, registration);
      }
    } else {
      void* builtin_data = nullptr;
      MallocDataAllocator malloc_allocator;
      TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
                                        &malloc_allocator, &builtin_data));
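      // The subgraph takes ownership of builtin_data here and releases it
      // with free() when the node is destroyed.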
      subgraph->AddNodeWithParameters(
          FlatBufferIntArrayToVector(op->inputs()),
          FlatBufferIntArrayToVector(op->outputs()),
          FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
          builtin_data, registration);
    }
  }

  return status;
}

TfLiteStatus InterpreterBuilder::ParseQuantization(
    const QuantizationParameters* src_quantization,
    TfLiteQuantization* quantization, const std::vector<int>& dims) {
  quantization->type = kTfLiteNoQuantization;
  if (!src_quantization || !src_quantization->scale() ||
      src_quantization->scale()->size() == 0) {
    return kTfLiteOk;
  }
  if (!src_quantization->zero_point()) {
    error_reporter_->Report(
        "Quantization parameters have non-null scale but null zero_point.");
    return kTfLiteError;
  }

  // Ensure that the number of scales matches the number of zero_points.
  if (src_quantization->scale()->size() !=
      src_quantization->zero_point()->size()) {
    error_reporter_->Report(
        "QuantizationParam has %d zero_point values and %d scale values. Must "
        "have the same number.",
        src_quantization->zero_point()->size(),
        src_quantization->scale()->size());
    return kTfLiteError;
  }

  const size_t num_scales = src_quantization->scale()->size();

  // Ensure that the quantization dimension is valid.
  if (src_quantization->quantized_dimension() < 0 ||
      (!dims.empty() &&
       src_quantization->quantized_dimension() >= dims.size())) {
    error_reporter_->Report(
        "quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
        src_quantization->quantized_dimension());
    return kTfLiteError;
  }

  // Ensure that the number of scales is 1 for per-layer quantization, and
  // matches the number of elements along the quantized dimension for per-axis
  // quantization.
  if (num_scales != 1 &&
      (!dims.empty() &&
       num_scales != dims[src_quantization->quantized_dimension()])) {
    error_reporter_->Report(
        "num_scales must be 1 for per-layer quantization, or %d for per-axis "
        "quantization, but got %d.",
        dims[src_quantization->quantized_dimension()], num_scales);
    return kTfLiteError;
  }
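  // Example: a per-axis quantized convolution filter of shape [64, 3, 3, 3]
  // with quantized_dimension == 0 carries 64 scales and 64 zero points, one
  // pair per output channel.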

  // Affine-quantization.
  quantization->type = kTfLiteAffineQuantization;
  auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
  affine_quantization->zero_point = TfLiteIntArrayCreate(num_scales);
  for (size_t i = 0; i < num_scales; ++i) {
    affine_quantization->scale->data[i] = src_quantization->scale()->Get(i);
    affine_quantization->zero_point->data[i] =
        src_quantization->zero_point()->Get(i);
  }
  affine_quantization->quantized_dimension =
      src_quantization->quantized_dimension();
  quantization->params = reinterpret_cast<void*>(affine_quantization);
  return kTfLiteOk;
}

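// Converts flatbuffer SparsityParameters into a heap-allocated TfLiteSparsity
// struct in *sparsity_ptr. A null src_sparsity is valid (dense tensor) and
// leaves *sparsity_ptr untouched.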
TfLiteStatus InterpreterBuilder::ParseSparsity(
    const SparsityParameters* src_sparsity, TfLiteSparsity** sparsity_ptr) {
  if (!src_sparsity) {
    return kTfLiteOk;
  }

  if (src_sparsity->traversal_order() == nullptr ||
      src_sparsity->dim_metadata() == nullptr) {
    error_reporter_->Report("Invalid sparsity parameter.");
    return kTfLiteError;
  }

  auto* sparsity =
      reinterpret_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
  memset(sparsity, 0, sizeof(TfLiteSparsity));
  *sparsity_ptr = sparsity;

  const size_t traversal_order_size = src_sparsity->traversal_order()->size();
  sparsity->traversal_order = TfLiteIntArrayCreate(traversal_order_size);
  for (int i = 0; i < traversal_order_size; i++) {
    sparsity->traversal_order->data[i] =
        src_sparsity->traversal_order()->Get(i);
  }

  if (src_sparsity->block_map()) {
    const size_t block_map_size = src_sparsity->block_map()->size();
    sparsity->block_map = TfLiteIntArrayCreate(block_map_size);
    for (int i = 0; i < block_map_size; i++) {
      sparsity->block_map->data[i] = src_sparsity->block_map()->Get(i);
    }
  }

  const size_t dim_metadata_size = src_sparsity->dim_metadata()->size();
  sparsity->dim_metadata_size = dim_metadata_size;
  sparsity->dim_metadata = reinterpret_cast<TfLiteDimensionMetadata*>(
      malloc(dim_metadata_size * sizeof(TfLiteDimensionMetadata)));
  memset(sparsity->dim_metadata, 0,
         dim_metadata_size * sizeof(TfLiteDimensionMetadata));

  for (int i = 0; i < dim_metadata_size; i++) {
    const auto* src_metadata = src_sparsity->dim_metadata()->Get(i);
    if (src_metadata->format() != DimensionType_DENSE &&
        src_metadata->format() != DimensionType_SPARSE_CSR) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "The %dth dimension has unknown type: %d.", i,
                           src_metadata->format());
      return kTfLiteError;
    }
    auto* tgt_metadata = &sparsity->dim_metadata[i];

    tgt_metadata->format =
        static_cast<TfLiteDimensionType>(src_metadata->format());

    if (tgt_metadata->format == kTfLiteDimDense) {
      tgt_metadata->dense_size = src_metadata->dense_size();
    } else {
      if (ParseSparseIndexVector(src_metadata, tgt_metadata) != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "The %dth sparse dimension has invalid parameters.", i);
        return kTfLiteError;
      }
    }
  }

  return kTfLiteOk;
}

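// Copies the model's SignatureDefs (exported method name plus input/output
// tensor maps) into the interpreter. A model with no SignatureDefs is valid.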
TfLiteStatus InterpreterBuilder::ParseSignatureDefs(
    const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
        signature_def_list,
    Interpreter* interpreter) {
  if (signature_def_list == nullptr || signature_def_list->size() == 0) {
    return kTfLiteOk;
  }
  std::vector<internal::SignatureDef> signature_defs;
  signature_defs.reserve(signature_def_list->size());
  for (const auto fb_signature_def : *signature_def_list) {
    if (fb_signature_def == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model.");
      return kTfLiteError;
    }
    if (fb_signature_def->signature_key() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Missing exported method name for SignatureDef");
      return kTfLiteError;
    }
    if (fb_signature_def->inputs() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "NULL SignatureDef inputs for exported method %s",
                           fb_signature_def->signature_key()->c_str());
      return kTfLiteError;
    }
    if (fb_signature_def->outputs() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "NULL SignatureDef outputs for exported method %s",
                           fb_signature_def->signature_key()->c_str());
      return kTfLiteError;
    }
    signature_defs.resize(signature_defs.size() + 1);
    auto& signature_def = signature_defs.back();
    signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs());
    signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs());
    signature_def.signature_key = fb_signature_def->signature_key()->c_str();
    signature_def.subgraph_index = fb_signature_def->subgraph_index();
  }
  interpreter->SetSignatureDef(std::move(signature_defs));
  return kTfLiteOk;
}

TfLiteStatus InterpreterBuilder::ParseTensors(
    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // A little helper to get the names of inputs and outputs. Note that the
  // returned names must outlive the subgraph.
  auto get_name = [](const tflite::Tensor* t) -> const char* {
    auto name = t->name();
    if (name) return name->c_str();
    return kEmptyTensorName;
  };

  num_fp32_tensors_ = 0;
  for (int i = 0; i < tensors->size(); ++i) {
    const auto* tensor = tensors->Get(i);
    std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());

    TfLiteType type;
    if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
        kTfLiteOk) {
      status = kTfLiteError;
      continue;
    }
    if (type == kTfLiteFloat32) {
      ++num_fp32_tensors_;
    }
    auto get_readonly_data = [&](const char** buffer_data,
                                 size_t* buffer_size) {
      // TODO(aselle): Check what happens if we have an unspecified size
      // constant.
      *buffer_data = nullptr;
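      // By schema convention, buffer index 0 is a sentinel for "no data": the
      // 0th entry of the buffers vector must always be an empty buffer.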
      if (tensor->buffer() == 0) return kTfLiteOk;
      if (tensor->buffer() >= buffers->size()) {
        error_reporter_->Report(
            "Tensor %d specifies an out-of-range buffer %d (only %d "
            "buffers).\n",
            i, tensor->buffer(), buffers->size());
        return kTfLiteError;
      }
      if (auto* buffer = (*buffers)[tensor->buffer()]) {
        if (auto* array = buffer->data()) {
          *buffer_size = array->size();
          *buffer_data = reinterpret_cast<const char*>(array->data());
          return kTfLiteOk;
        }
      }
      return kTfLiteOk;
    };
    size_t buffer_size = 0;
    const char* buffer_ptr;
    TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));

    const auto* src_quantization = tensor->quantization();
    TfLiteQuantization quantization;
    if (ParseQuantization(src_quantization, &quantization, dims) !=
        kTfLiteOk) {
      error_reporter_->Report("Tensor %d has invalid quantization parameters.",
                              i);
      status = kTfLiteError;
    }

    std::vector<int> dims_signature = {};
    if (tensor->shape_signature()) {
      dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
    }

    bool is_variable = tensor->is_variable();
    if (buffer_ptr) {
      if (is_variable) {
        error_reporter_->Report(
            "Tensor %d is a variable tensor with a buffer; this is not yet "
            "supported.\n",
            i);
        status = kTfLiteError;
      }

      // TODO(b/144999664): Only constant sparse tensor is supported now.
      const auto* src_sparsity = tensor->sparsity();
      TfLiteSparsity* sparsity = nullptr;
      if (ParseSparsity(src_sparsity, &sparsity) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d has invalid sparsity parameters.",
                                i);
        status = kTfLiteError;
      }

      if (subgraph->SetTensorParametersReadOnly(
              i, type, get_name(tensor), dims, quantization, buffer_ptr,
              buffer_size, allocation_, sparsity) != kTfLiteOk) {
        error_reporter_->Report(
            "Tensor %d is invalidly specified in schema.\n", i);
        status = kTfLiteError;
      }
    } else {
      if (subgraph->SetTensorParametersReadWrite(
              i, type, get_name(tensor), dims, quantization, is_variable,
              dims_signature) != kTfLiteOk) {
        error_reporter_->Report(
            "Tensor %d is invalidly specified in schema.\n", i);
        status = kTfLiteError;
      }
    }
  }

  return status;
}

TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) {
  // Apply Flex delegate if applicable.
  if (has_flex_op_) {
    if (Interpreter::TfLiteDelegatePtr flex_delegate = AcquireFlexDelegate()) {
      TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegateImpl(
          // Transfers ownership of flex_delegate to the interpreter.
          std::move(flex_delegate)));
    }
  }
  for (TfLiteDelegate* delegate : delegates_) {
    // Note that we DON'T transfer ownership of the delegate to the
    // interpreter. (Doing that would cause problems if operator() was invoked
    // twice.)
    TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegateImpl(delegate));
  }
  return kTfLiteOk;
}

TfLiteStatus InterpreterBuilder::SetNumThreads(int num_threads) {
  if (num_threads < -1) {
    error_reporter_->Report(
        "num_threads should be >= 0, or -1 to let the TFLite runtime choose "
        "the value.");
    return kTfLiteError;
  }
  num_threads_ = num_threads;
  return kTfLiteOk;
}

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
  TfLiteStatus status = SetNumThreads(num_threads);
  if (status != kTfLiteOk) {
    interpreter->reset();
    return status;
  }
  return (*this)(interpreter);
}

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter) {
  if (!interpreter) {
    error_reporter_->Report(
        "Null output pointer passed to InterpreterBuilder.");
    return kTfLiteError;
  }

  // Safe exit by deleting a partially created interpreter, to reduce verbosity
  // on error conditions. Use via `return cleanup_and_error();`.
  auto cleanup_and_error = [&interpreter]() {
    interpreter->reset();
    return kTfLiteError;
  };

  if (!model_) {
    error_reporter_->Report("Null pointer passed in as model.");
    return cleanup_and_error();
  }

  if (model_->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter_->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.\n",
        model_->version(), TFLITE_SCHEMA_VERSION);
    return cleanup_and_error();
  }

  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

  // Flatbuffer model schemas define a list of opcodes independent of the
  // graph. We first map those to registrations. This reduces string lookups
  // for custom ops since we only do it once per custom op rather than once
  // per custom op invocation in the model graph.
  // Construct interpreter with correct number of tensors and operators.
  auto* subgraphs = model_->subgraphs();
  auto* buffers = model_->buffers();

  if (!subgraphs || subgraphs->size() == 0) {
    TF_LITE_REPORT_ERROR(error_reporter_, "No subgraph in the model.\n");
    return cleanup_and_error();
  }

  if (!buffers) {
    TF_LITE_REPORT_ERROR(error_reporter_, "No buffers in the model.\n");
    return cleanup_and_error();
  }

  *interpreter = std::make_unique<Interpreter>(error_reporter_);
  if (subgraphs->size() > 1) {
    (*interpreter)->AddSubgraphs(subgraphs->size() - 1);
  }

  // Set num threads after all the subgraphs are added.
  (*interpreter)->SetNumThreads(num_threads_);

  // Set Interpreter options.
  (*interpreter)->ApplyOptionsImpl(&options_);

  (*interpreter)
      ->SetProfilerImpl(tflite::profiling::MaybeCreatePlatformProfiler());

  for (int subgraph_index = 0; subgraph_index < subgraphs->size();
       ++subgraph_index) {
    const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
    tflite::Subgraph* modified_subgraph =
        (*interpreter)->subgraph(subgraph_index);
    auto operators = subgraph->operators();
    auto tensors = subgraph->tensors();
    if (!tensors) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Did not get tensors in subgraph %d.\n",
                           subgraph_index);
      return cleanup_and_error();
    }
    if (modified_subgraph->AddTensors(tensors->size()) != kTfLiteOk) {
      return cleanup_and_error();
    }
    // Parse inputs/outputs.
    modified_subgraph->SetInputs(
        FlatBufferIntArrayToVector(subgraph->inputs()));
    modified_subgraph->SetOutputs(
        FlatBufferIntArrayToVector(subgraph->outputs()));

    // Finally, set up nodes and tensors. Parse tensors before nodes, as
    // ParseNodes checks input tensors for the nodes.
    if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk)
      return cleanup_and_error();
    if (operators && ParseNodes(operators, modified_subgraph) != kTfLiteOk)
      return cleanup_and_error();

    std::vector<int> variables;
    for (int i = 0; i < modified_subgraph->tensors_size(); ++i) {
      auto* tensor = modified_subgraph->tensor(i);
      if (tensor->is_variable) {
        variables.push_back(i);
      }
    }
    modified_subgraph->SetVariables(std::move(variables));
    if (subgraph->name()) {
      modified_subgraph->SetName(subgraph->name()->c_str());
    }
  }

  if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
      kTfLiteOk) {
    return cleanup_and_error();
  }

  if ((*interpreter)->SetMetadata(metadata_) != kTfLiteOk) {
    return cleanup_and_error();
  }

  if (ShouldCreateLazyDelegateProviders(num_fp32_tensors_)) {
    (*interpreter)->lazy_delegate_providers_ =
        op_resolver_.GetDelegateCreators();
  }

  TfLiteStatus status = ApplyDelegates(interpreter->get());
  if (status != kTfLiteOk) {
    interpreter->reset();
  }
  return status;
}

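// Note: the builder stores only a raw pointer; the caller retains ownership
// of the delegate and must keep it alive as long as any interpreter built
// with it. A minimal sketch, assuming the XNNPACK delegate headers are
// available:
//
//   TfLiteDelegate* xnnpack = TfLiteXNNPackDelegateCreate(/*options=*/nullptr);
//   builder.AddDelegate(xnnpack);
//   // ... build one or more interpreters ...
//   TfLiteXNNPackDelegateDelete(xnnpack);  // after interpreters are destroyed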
void InterpreterBuilder::AddDelegate(TfLiteDelegate* delegate) {
  if (delegate == nullptr) {
    TF_LITE_REPORT_ERROR(error_reporter_, "Null delegate.");
  } else {
    delegates_.push_back(delegate);
  }
}

}  // namespace tflite