1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/interpreter_builder.h"
16
17#include <stddef.h>
18#include <stdint.h>
19#include <stdlib.h>
20#include <string.h>
21
22#include <algorithm>
23#include <map>
24#include <memory>
25#include <string>
26#include <utility>
27#include <vector>
28
29#include "flatbuffers/flatbuffers.h" // from @flatbuffers
30#include "tensorflow/lite/c/c_api_types.h"
31#include "tensorflow/lite/core/api/error_reporter.h"
32#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
33#include "tensorflow/lite/core/api/op_resolver.h"
34#include "tensorflow/lite/core/macros.h"
35#include "tensorflow/lite/core/subgraph.h"
36#include "tensorflow/lite/internal/signature_def.h"
37#include "tensorflow/lite/interpreter.h"
38#include "tensorflow/lite/kernels/internal/compatibility.h"
39#include "tensorflow/lite/model_builder.h"
40#include "tensorflow/lite/profiling/platform_profiler.h"
41#include "tensorflow/lite/schema/schema_generated.h"
42#include "tensorflow/lite/schema/schema_utils.h"
43#include "tensorflow/lite/shared_library.h"
44#include "tensorflow/lite/stderr_reporter.h"
45#include "tensorflow/lite/string_type.h"
46#include "tensorflow/lite/util.h"
47#include "tensorflow/lite/version.h"
48
49// aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
50#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
51#if !defined(__ANDROID__) || __ANDROID_API__ >= 28
52// Neither Apple nor Windows provide aligned_alloc.
53#if !defined(__APPLE__) && !defined(_WIN32)
54#define TFLITE_USE_STD_ALIGNED_ALLOC
55#endif
56#endif
57#endif
58
59// TODO(b/139446230): Move to portable platform header.
60#if defined(__ANDROID__)
61#define TFLITE_IS_MOBILE_PLATFORM
62#endif // defined(__ANDROID__)
63
64#if defined(__APPLE__)
65#include "TargetConditionals.h"
66#if TARGET_IPHONE_SIMULATOR
67#define TFLITE_IS_MOBILE_PLATFORM
68#elif TARGET_OS_IPHONE
69#define TFLITE_IS_MOBILE_PLATFORM
70#endif
71#endif // defined(__APPLE__)
72
73namespace tflite {
74
75namespace {
76
77// Ensure that ErrorReporter is non-null.
78ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
79 return e ? e : DefaultErrorReporter();
80}
81
82template <typename T>
83TfLiteStatus Copy(const T* data_ptr, TfLiteIntArray** arr) {
84 if (data_ptr->values() == nullptr) {
85 return kTfLiteError;
86 }
87
88 int size = data_ptr->values()->size();
89 *arr = TfLiteIntArrayCreate(size);
90 for (int i = 0; i < size; i++) {
91 (*arr)->data[i] = static_cast<int>(data_ptr->values()->Get(i));
92 }
93 return kTfLiteOk;
94}
95
96TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
97 TfLiteDimensionMetadata* tgt) {
98 if (src->array_segments() == nullptr || src->array_indices() == nullptr) {
99 return kTfLiteError;
100 }
101 TfLiteStatus status = kTfLiteOk;
102 switch (src->array_segments_type()) {
103 case SparseIndexVector_Int32Vector:
104 status = Copy(src->array_segments_as_Int32Vector(), &tgt->array_segments);
105 break;
106 case SparseIndexVector_Uint16Vector:
107 status =
108 Copy(src->array_segments_as_Uint16Vector(), &tgt->array_segments);
109 break;
110 case SparseIndexVector_Uint8Vector:
111 status = Copy(src->array_segments_as_Uint8Vector(), &tgt->array_segments);
112 break;
113 default:
114 status = kTfLiteError;
115 break;
116 }
117 if (status != kTfLiteOk) return status;
118
119 switch (src->array_indices_type()) {
120 case SparseIndexVector_Int32Vector:
121 return Copy(src->array_indices_as_Int32Vector(), &tgt->array_indices);
122 case SparseIndexVector_Uint16Vector:
123 return Copy(src->array_indices_as_Uint16Vector(), &tgt->array_indices);
124 case SparseIndexVector_Uint8Vector:
125 return Copy(src->array_indices_as_Uint8Vector(), &tgt->array_indices);
126 default:
127 break;
128 }
129 return kTfLiteError;
130}
131
132// Helper that returns std::map that corresponds to vector of TensorMap.
133std::map<std::string, uint32_t> GetMapFromTensorMap(
134 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>>*
135 tensor_map) {
136 if (!tensor_map) return {};
137 std::map<std::string, uint32_t> result;
138 for (const auto tensor : *tensor_map) {
139 if (tensor != nullptr && tensor->name() != nullptr) {
140 result[tensor->name()->c_str()] = tensor->tensor_index();
141 }
142 }
143 return result;
144}
145
// Decides whether the interpreter should receive the op resolver's lazy
// delegate creators. When XNNPACK_DELEGATE_ENABLE_QS8 or
// XNNPACK_DELEGATE_ENABLE_QU8 is defined at build time the answer is always
// yes; otherwise only models with at least one float32 tensor qualify.
inline bool ShouldCreateLazyDelegateProviders(int num_fp32_tensors) {
#if defined(XNNPACK_DELEGATE_ENABLE_QS8) || defined(XNNPACK_DELEGATE_ENABLE_QU8)
  static_cast<void>(num_fp32_tensors);  // Not needed in this configuration.
  return true;
#else
  return 0 < num_fp32_tensors;
#endif
}
153
154} // namespace
155
// Name reported for tensors whose flatbuffer entry has no `name` field (used by
// the get_name helper in ParseTensors). It points at a string literal, so it
// stays valid for as long as any subgraph refers to it.
constexpr const char* kEmptyTensorName = "";
157
// Using weak symbols to create a delegate allows automatic injection of the
// delegate simply by adding it as a dependency.
// For flex delegate, see also the strong override in
// lite/delegates/flex/delegate.cc.
//
// Lookup order in this weak default implementation:
//   1. (non-Android only) a TF_AcquireFlexDelegate symbol already visible to
//      SharedLibrary::GetSymbol;
//   2. (non-mobile platforms only) TF_AcquireFlexDelegate exported by the
//      TensorFlow Python extension library, loaded by file name (on Windows,
//      falling back to _pywrap_tensorflow_interpreter_wrapper.pyd);
//   3. otherwise a null delegate with a no-op deleter, which ApplyDelegates
//      treats as "no Flex delegate available".
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
  // TF_AcquireFlexDelegate isn't defined on Android, and the following block of
  // code would have no effect if TF_AcquireFlexDelegate isn't defined, so we
  // only enable that block for non-Android platforms. Also, on Android 4.4
  // (Kitkat), the dlsym() implementation has a bug where dlsym() of an unknown
  // name will result in a SIGFPE, which would crash the process, so it's
  // important that on Android 4.4 we *don't* call SharedLibrary::GetSymbol
  // unless the symbol is sure to exist.
#if !defined(__ANDROID__)
  auto acquire_flex_delegate_func =
      reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
          SharedLibrary::GetSymbol("TF_AcquireFlexDelegate"));
  if (acquire_flex_delegate_func) {
    return acquire_flex_delegate_func();
  }
#endif

#if !defined(TFLITE_IS_MOBILE_PLATFORM)
  // Load TF_AcquireFlexDelegate() from _pywrap_tensorflow_internal.so if it is
  // available.
  // NOTE: this block reuses `acquire_flex_delegate_func`, which is declared in
  // the !__ANDROID__ block above. That compiles because this file defines
  // TFLITE_IS_MOBILE_PLATFORM whenever __ANDROID__ is defined, so this block
  // is never compiled on Android.
#if defined(_WIN32)
  const wchar_t* filename_pywrap_tensorflow_internal =
      L"_pywrap_tensorflow_internal.pyd";
#elif defined(__APPLE__)
  const char* filename_pywrap_tensorflow_internal =
      "python/_pywrap_tensorflow_internal.so";
#else
  const char* filename_pywrap_tensorflow_internal =
      "_pywrap_tensorflow_internal.so";
#endif
  // NOTE(review): the library handle is never released here. Presumably this
  // is deliberate, because the returned delegate's code lives in that
  // library — confirm before adding an unload.
  void* lib_tf_internal =
      SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
#if defined(_WIN32)
  // On Windows, fall back to the interpreter-wrapper extension module.
  if (lib_tf_internal == nullptr) {
    lib_tf_internal = SharedLibrary::LoadLibrary(
        L"_pywrap_tensorflow_interpreter_wrapper.pyd");
  }
#endif
  if (lib_tf_internal) {
    acquire_flex_delegate_func =
        reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
            SharedLibrary::GetLibrarySymbol(lib_tf_internal,
                                            "TF_AcquireFlexDelegate"));
    if (acquire_flex_delegate_func) {
      return acquire_flex_delegate_func();
    }
  }
#endif  // !defined(TFLITE_IS_MOBILE_PLATFORM)

  // No provider found: return an empty delegate whose deleter does nothing.
  return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}
213
214InterpreterBuilder::InterpreterBuilder(
215 const FlatBufferModel& model, const OpResolver& op_resolver,
216 const InterpreterOptions* options_experimental)
217 : model_(model.GetModel()),
218 op_resolver_(op_resolver),
219 error_reporter_(ValidateErrorReporter(model.error_reporter())),
220 metadata_(model.ReadAllMetadata()),
221 allocation_(model.allocation()) {
222 if (options_experimental) {
223 options_ = *options_experimental;
224 }
225}
226
227InterpreterBuilder::InterpreterBuilder(
228 const ::tflite::Model* model, const OpResolver& op_resolver,
229 ErrorReporter* error_reporter,
230 const InterpreterOptions* options_experimental)
231 : model_(model),
232 op_resolver_(op_resolver),
233 error_reporter_(ValidateErrorReporter(error_reporter)) {
234 if (options_experimental) {
235 options_ = *options_experimental;
236 }
237}
238
239InterpreterBuilder::~InterpreterBuilder() {}
240
241TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
242 TfLiteStatus status = kTfLiteOk;
243 // Reset state.
244 flatbuffer_op_index_to_registration_.clear();
245 unresolved_custom_ops_.clear();
246
247 auto opcodes = model_->operator_codes();
248 if (!opcodes) {
249 return status;
250 }
251 int num_custom_ops = 0;
252 for (const OperatorCode* opcode : *opcodes) {
253 if (GetBuiltinCode(opcode) == BuiltinOperator_CUSTOM) {
254 num_custom_ops++;
255 }
256 }
257 unresolved_custom_ops_.reserve(num_custom_ops);
258 for (const OperatorCode* opcode : *opcodes) {
259 const TfLiteRegistration* registration = nullptr;
260 status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
261 &registration);
262 if (status != kTfLiteOk) {
263 if (GetBuiltinCode(opcode) != BuiltinOperator_CUSTOM) {
264 return status;
265 }
266 // If it's an unresolved custom op, allow it for now. It might be resolved
267 // by a delegate later.
268 if (!opcode->custom_code()) {
269 error_reporter_->Report(
270 "Operator with CUSTOM builtin_code has no custom_code.\n");
271 return status;
272 }
273 const auto* op_name = opcode->custom_code()->c_str();
274 unresolved_custom_ops_.push_back(CreateUnresolvedCustomOp(op_name));
275 registration = &unresolved_custom_ops_.back();
276 has_flex_op_ |= IsFlexOp(op_name);
277 status = kTfLiteOk;
278 }
279 flatbuffer_op_index_to_registration_.push_back(registration);
280 }
281 return status;
282}
283
284namespace {
// Converts a flatbuffer integer vector (anything exposing size() and Get(i))
// into a std::vector<int>. A null pointer yields an empty vector: models built
// with flatbuffers::Pack store empty vectors as null, so both must look alike
// to callers (for example, a tensor with a null shape).
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  std::vector<int> result;
  if (flat_array != nullptr) {
    const auto count = flat_array->size();
    result.reserve(count);
    for (decltype(flat_array->size()) k = 0; k < count; ++k) {
      result.push_back(flat_array->Get(k));
    }
  }
  return result;
}
298
299// Used to determine how the op data parsing function creates its working space.
300class MallocDataAllocator : public BuiltinDataAllocator {
301 public:
302 void* Allocate(size_t size, size_t alignment_hint) override {
303#ifdef TFLITE_USE_STD_ALIGNED_ALLOC
304 // Ensure that alignment is a power of two and a multiple of sizeof(void *)
305 // and that size is an integral multiple of alignment.
306 size_t used_alignment = std::max(alignment_hint, sizeof(void*));
307 size_t used_size =
308 ((size + used_alignment - 1) / used_alignment) * used_alignment;
309 TFLITE_DCHECK(
310 (used_alignment != 0) &&
311 ((used_alignment & (used_alignment - 1)) == 0)); // is power-of-two
312 return aligned_alloc(used_alignment, used_size);
313#else
314 return malloc(size);
315#endif
316 }
317 void Deallocate(void* data) override { free(data); }
318};
319
320} // namespace
321
322TfLiteStatus InterpreterBuilder::ParseNodes(
323 const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
324 Subgraph* subgraph) {
325 TfLiteStatus status = kTfLiteOk;
326
327 // Reduce the number of redundant allocations
328 subgraph->ReserveNodes(operators->size());
329
330 for (int i = 0; i < operators->size(); ++i) {
331 const auto* op = operators->Get(i);
332 int index = op->opcode_index();
333 if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
334 error_reporter_->Report("Missing registration for opcode_index %d\n",
335 index);
336 status = kTfLiteError;
337 continue;
338 }
339
340 const TfLiteRegistration* registration =
341 flatbuffer_op_index_to_registration_[index];
342 if (registration == nullptr) {
343 error_reporter_->Report("Skipping op for opcode_index %d\n", index);
344 status = kTfLiteError;
345 continue;
346 }
347
348 BuiltinOperator op_type =
349 static_cast<BuiltinOperator>(registration->builtin_code);
350
351 if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
352 error_reporter_->Report(
353 "Found builtin operator %s with custom options.\n",
354 EnumNameBuiltinOperator(op_type));
355 }
356
357 if (op_type == BuiltinOperator_CUSTOM) {
358 if (op->custom_options()) {
359 subgraph->AddNodeWithParameters(
360 FlatBufferIntArrayToVector(op->inputs()),
361 FlatBufferIntArrayToVector(op->outputs()),
362 FlatBufferIntArrayToVector(op->intermediates()),
363 reinterpret_cast<const char*>(op->custom_options()->data()),
364 op->custom_options()->size(), nullptr, registration);
365 } else {
366 subgraph->AddNodeWithParameters(
367 FlatBufferIntArrayToVector(op->inputs()),
368 FlatBufferIntArrayToVector(op->outputs()),
369 FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
370 nullptr, registration);
371 }
372 } else {
373 void* builtin_data = nullptr;
374 MallocDataAllocator malloc_allocator;
375 TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
376 &malloc_allocator, &builtin_data));
377 subgraph->AddNodeWithParameters(
378 FlatBufferIntArrayToVector(op->inputs()),
379 FlatBufferIntArrayToVector(op->outputs()),
380 FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
381 builtin_data, registration);
382 }
383 }
384
385 return status;
386}
387
388TfLiteStatus InterpreterBuilder::ParseQuantization(
389 const QuantizationParameters* src_quantization,
390 TfLiteQuantization* quantization, const std::vector<int>& dims) {
391 quantization->type = kTfLiteNoQuantization;
392 if (!src_quantization || !src_quantization->scale() ||
393 src_quantization->scale()->size() == 0) {
394 return kTfLiteOk;
395 }
396 if (!src_quantization->zero_point()) {
397 error_reporter_->Report(
398 "Quantization parameters has non-null scale but null zero_point.");
399 return kTfLiteError;
400 }
401
402 // Ensure that the number of scales matches the number of zero_points.
403 if (src_quantization->scale()->size() !=
404 src_quantization->zero_point()->size()) {
405 error_reporter_->Report(
406 "QuantizationParam has %d zero_point values and %d scale values. Must "
407 "have same number.",
408 src_quantization->zero_point()->size(),
409 src_quantization->scale()->size());
410 return kTfLiteError;
411 }
412
413 const size_t num_scales = src_quantization->scale()->size();
414
415 // Ensure that the quantization dimension is valid.
416 if (src_quantization->quantized_dimension() < 0 ||
417 (!dims.empty() &&
418 src_quantization->quantized_dimension() >= dims.size())) {
419 error_reporter_->Report(
420 "quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
421 src_quantization->quantized_dimension());
422 return kTfLiteError;
423 }
424
425 // Ensure that the number of scales is 1 for per-layer quantization, and
426 // matches number of quantization dimensions for per-axis quantization.
427 if (num_scales != 1 &&
428 (!dims.empty() &&
429 num_scales != dims[src_quantization->quantized_dimension()])) {
430 error_reporter_->Report(
431 "num_scales must be 1 for per-layer quantization, or %d for per-axis "
432 "quantization, but got %d.",
433 dims[src_quantization->quantized_dimension()], num_scales);
434 return kTfLiteError;
435 }
436
437 // Affine-quantization.
438 quantization->type = kTfLiteAffineQuantization;
439 auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
440 malloc(sizeof(TfLiteAffineQuantization)));
441 affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
442 affine_quantization->zero_point = TfLiteIntArrayCreate(num_scales);
443 for (size_t i = 0; i < num_scales; ++i) {
444 affine_quantization->scale->data[i] = src_quantization->scale()->Get(i);
445 affine_quantization->zero_point->data[i] =
446 src_quantization->zero_point()->Get(i);
447 }
448 affine_quantization->quantized_dimension =
449 src_quantization->quantized_dimension();
450 quantization->params = reinterpret_cast<void*>(affine_quantization);
451 return kTfLiteOk;
452}
453
454TfLiteStatus InterpreterBuilder::ParseSparsity(
455 const SparsityParameters* src_sparsity, TfLiteSparsity** sparsity_ptr) {
456 if (!src_sparsity) {
457 return kTfLiteOk;
458 }
459
460 if (src_sparsity->traversal_order() == nullptr ||
461 src_sparsity->dim_metadata() == nullptr) {
462 error_reporter_->Report("Invalid sparsity parameter.");
463 return kTfLiteError;
464 }
465
466 auto* sparsity =
467 reinterpret_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
468 memset(sparsity, 0, sizeof(TfLiteSparsity));
469 *sparsity_ptr = sparsity;
470
471 const size_t traversal_order_size = src_sparsity->traversal_order()->size();
472 sparsity->traversal_order = TfLiteIntArrayCreate(traversal_order_size);
473 for (int i = 0; i < traversal_order_size; i++) {
474 sparsity->traversal_order->data[i] =
475 src_sparsity->traversal_order()->Get(i);
476 }
477
478 if (src_sparsity->block_map()) {
479 const size_t block_map_size = src_sparsity->block_map()->size();
480 sparsity->block_map = TfLiteIntArrayCreate(block_map_size);
481 for (int i = 0; i < block_map_size; i++) {
482 sparsity->block_map->data[i] = src_sparsity->block_map()->Get(i);
483 }
484 }
485
486 const size_t dim_metadata_size = src_sparsity->dim_metadata()->size();
487 sparsity->dim_metadata_size = dim_metadata_size;
488 sparsity->dim_metadata = reinterpret_cast<TfLiteDimensionMetadata*>(
489 malloc(dim_metadata_size * sizeof(TfLiteDimensionMetadata)));
490 memset(sparsity->dim_metadata, 0,
491 dim_metadata_size * sizeof(TfLiteDimensionMetadata));
492
493 for (int i = 0; i < dim_metadata_size; i++) {
494 const auto* src_metadata = src_sparsity->dim_metadata()->Get(i);
495 if (src_metadata->format() != DimensionType_DENSE &&
496 src_metadata->format() != DimensionType_SPARSE_CSR) {
497 TF_LITE_REPORT_ERROR(error_reporter_,
498 "The %dth dimension has unknown type: %d.", i,
499 src_metadata->format());
500 return kTfLiteError;
501 }
502 auto* tgt_metadata = &sparsity->dim_metadata[i];
503
504 tgt_metadata->format =
505 static_cast<TfLiteDimensionType>(src_metadata->format());
506
507 if (tgt_metadata->format == kTfLiteDimDense) {
508 tgt_metadata->dense_size = src_metadata->dense_size();
509 } else {
510 if (ParseSparseIndexVector(src_metadata, tgt_metadata) != kTfLiteOk) {
511 TF_LITE_REPORT_ERROR(
512 error_reporter_,
513 "The %dth sparse dimension has invalid parameters.", i);
514 return kTfLiteError;
515 }
516 }
517 }
518
519 return kTfLiteOk;
520}
521
522TfLiteStatus InterpreterBuilder::ParseSignatureDefs(
523 const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
524 signature_def_list,
525 Interpreter* interpreter) {
526 if (signature_def_list == nullptr || signature_def_list->size() == 0) {
527 return kTfLiteOk;
528 }
529 std::vector<internal::SignatureDef> signature_defs;
530 signature_defs.reserve(signature_def_list->size());
531 for (const auto fb_signature_def : *signature_def_list) {
532 if (fb_signature_def == nullptr) {
533 TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model.");
534 return kTfLiteError;
535 }
536 if (fb_signature_def->signature_key() == nullptr) {
537 TF_LITE_REPORT_ERROR(error_reporter_,
538 "Missing exported method name for SignatureDef");
539 return kTfLiteError;
540 }
541 if (fb_signature_def->inputs() == nullptr) {
542 TF_LITE_REPORT_ERROR(error_reporter_,
543 "NULL SignatureDef inputs for exported method %s",
544 fb_signature_def->signature_key()->c_str());
545 return kTfLiteError;
546 }
547 if (fb_signature_def->outputs() == nullptr) {
548 TF_LITE_REPORT_ERROR(error_reporter_,
549 "NULL SignatureDef outputs for exported method %s",
550 fb_signature_def->signature_key()->c_str());
551 return kTfLiteError;
552 }
553 signature_defs.resize(signature_defs.size() + 1);
554 auto& signature_def = signature_defs.back();
555 signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs());
556 signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs());
557 signature_def.signature_key = fb_signature_def->signature_key()->c_str();
558 signature_def.subgraph_index = fb_signature_def->subgraph_index();
559 }
560 interpreter->SetSignatureDef(std::move(signature_defs));
561 return kTfLiteOk;
562}
563
// Populates the tensors of `subgraph` from the flatbuffer `tensors` vector.
//
// For each tensor i the function:
//   - converts its type and counts float32 tensors in num_fp32_tensors_;
//   - parses its quantization;
//   - registers it in one of two ways:
//     * read-only over its constant buffer, when its buffer index is non-zero
//       and that buffer has data. Sparsity is parsed on this path too.
//     * read-write otherwise, passing is_variable and the optional shape
//       signature.
// num_fp32_tensors_ is reset to 0 at the start, so after the call it holds the
// count for this `tensors` vector only.
//
// Most per-tensor errors are reported and recorded in the returned status
// while parsing moves on to the next tensor. An out-of-range buffer index,
// however, makes the function return kTfLiteError immediately
// (TF_LITE_ENSURE_STATUS).
TfLiteStatus InterpreterBuilder::ParseTensors(
    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // A little helper to get the names of inputs and outputs. Note that they
  // must outlive the subgraph.
  auto get_name = [](const tflite::Tensor* t) -> const char* {
    auto name = t->name();
    if (name) return name->c_str();
    return kEmptyTensorName;
  };

  num_fp32_tensors_ = 0;
  for (int i = 0; i < tensors->size(); ++i) {
    const auto* tensor = tensors->Get(i);
    std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());

    TfLiteType type;
    if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
        kTfLiteOk) {
      // Unknown type: tensor i is left unconfigured.
      status = kTfLiteError;
      continue;
    }
    if (type == kTfLiteFloat32) {
      ++num_fp32_tensors_;
    }
    // Looks up the constant data of `tensor`. Buffer index 0, a null buffer
    // entry, or a buffer without data all yield *buffer_data == nullptr and
    // kTfLiteOk. Only an out-of-range index is an error.
    auto get_readonly_data = [&](const char** buffer_data,
                                 size_t* buffer_size) {
      // TODO(aselle): Check what happens if we have an unspecified size
      // constant.
      *buffer_data = nullptr;
      if (tensor->buffer() == 0) return kTfLiteOk;
      if (tensor->buffer() >= buffers->size()) {
        error_reporter_->Report(
            "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
            i, tensor->buffer(), buffers->size());
        return kTfLiteError;
      }
      if (auto* buffer = (*buffers)[tensor->buffer()]) {
        if (auto* array = buffer->data()) {
          *buffer_size = array->size();
          *buffer_data = reinterpret_cast<const char*>(array->data());
          return kTfLiteOk;
        }
      }
      return kTfLiteOk;
    };
    size_t buffer_size = 0;
    // Always assigned by get_readonly_data before it is read.
    const char* buffer_ptr;
    TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));

    const auto* src_quantization = tensor->quantization();
    // ParseQuantization sets quantization.type on every path, including
    // failures, before the struct is passed on below.
    TfLiteQuantization quantization;
    if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
      error_reporter_->Report("Tensor %d has invalid quantization parameters.",
                              i);
      status = kTfLiteError;
    }

    // Only used by the read-write branch below.
    std::vector<int> dims_signature = {};
    if (tensor->shape_signature()) {
      dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
    }

    bool is_variable = tensor->is_variable();
    if (buffer_ptr) {
      // Constant tensor backed by model data.
      if (is_variable) {
        error_reporter_->Report(
            "Tensor %d is a variable tensor with buffer. "
            "It's not supported now.\n",
            i);
        status = kTfLiteError;
      }

      // TODO(b/144999664): Only constant sparse tensor is supported now.
      const auto* src_sparsity = tensor->sparsity();
      TfLiteSparsity* sparsity = nullptr;
      if (ParseSparsity(src_sparsity, &sparsity) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d has invalid sparsity parameters.",
                                i);
        status = kTfLiteError;
      }

      // NOTE(review): `sparsity` (possibly partially filled after a
      // ParseSparsity failure) and any malloc'ed quantization params are
      // handed to the subgraph here. Presumably the subgraph takes ownership —
      // confirm before freeing them in this function.
      if (subgraph->SetTensorParametersReadOnly(
              i, type, get_name(tensor), dims, quantization, buffer_ptr,
              buffer_size, allocation_, sparsity) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    } else {
      // No constant data: runtime-allocated (read-write) tensor.
      if (subgraph->SetTensorParametersReadWrite(
              i, type, get_name(tensor), dims, quantization, is_variable,
              dims_signature) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    }
  }

  return status;
}
669
670TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) {
671 // Apply Flex delegate if applicable.
672 if (has_flex_op_) {
673 if (Interpreter::TfLiteDelegatePtr flex_delegate = AcquireFlexDelegate()) {
674 TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegateImpl(
675 // Transfers ownership of flex_delegate to the interpreter.
676 std::move(flex_delegate)));
677 }
678 }
679 for (TfLiteDelegate* delegate : delegates_) {
680 // Note that we DON'T transfer ownership of the delegate to the interpreter.
681 // (Doing that would cause problems if operator() was invoked twice.)
682 TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegateImpl(delegate));
683 }
684 return kTfLiteOk;
685}
686
687TfLiteStatus InterpreterBuilder::SetNumThreads(int num_threads) {
688 if (num_threads < -1) {
689 error_reporter_->Report(
690 "num_threads should be >= 0 or just -1 to let TFLite runtime set the "
691 "value.");
692 return kTfLiteError;
693 }
694 num_threads_ = num_threads;
695 return kTfLiteOk;
696}
697
698TfLiteStatus InterpreterBuilder::operator()(
699 std::unique_ptr<Interpreter>* interpreter, int num_threads) {
700 TfLiteStatus status = SetNumThreads(num_threads);
701 if (status != kTfLiteOk) {
702 interpreter->reset();
703 return status;
704 }
705 return (*this)(interpreter);
706}
707
708TfLiteStatus InterpreterBuilder::operator()(
709 std::unique_ptr<Interpreter>* interpreter) {
710 if (!interpreter) {
711 error_reporter_->Report(
712 "Null output pointer passed to InterpreterBuilder.");
713 return kTfLiteError;
714 }
715
716 // Safe exit by deleting partially created interpreter, to reduce verbosity
717 // on error conditions. Use by return cleanup_on_error();
718 auto cleanup_and_error = [&interpreter]() {
719 interpreter->reset();
720 return kTfLiteError;
721 };
722
723 if (!model_) {
724 error_reporter_->Report("Null pointer passed in as model.");
725 return cleanup_and_error();
726 }
727
728 if (model_->version() != TFLITE_SCHEMA_VERSION) {
729 error_reporter_->Report(
730 "Model provided is schema version %d not equal "
731 "to supported version %d.\n",
732 model_->version(), TFLITE_SCHEMA_VERSION);
733 return cleanup_and_error();
734 }
735
736 if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
737 error_reporter_->Report("Registration failed.\n");
738 return cleanup_and_error();
739 }
740
741 // Flatbuffer model schemas define a list of opcodes independent of the graph.
742 // We first map those to registrations. This reduces string lookups for custom
743 // ops since we only do it once per custom op rather than once per custom op
744 // invocation in the model graph.
745 // Construct interpreter with correct number of tensors and operators.
746 auto* subgraphs = model_->subgraphs();
747 auto* buffers = model_->buffers();
748
749 if (subgraphs->size() == 0) {
750 TF_LITE_REPORT_ERROR(error_reporter_, "No subgraph in the model.\n");
751 return cleanup_and_error();
752 }
753
754 if (!buffers) {
755 TF_LITE_REPORT_ERROR(error_reporter_, "No buffers in the model.\n");
756 return cleanup_and_error();
757 }
758
759 *interpreter = std::make_unique<Interpreter>(error_reporter_);
760 if (subgraphs->size() > 1) {
761 (*interpreter)->AddSubgraphs(subgraphs->size() - 1);
762 }
763
764 // Set num threads after all the subgraphs are added.
765 (*interpreter)->SetNumThreads(num_threads_);
766
767 // Set Interpreter options
768 (*interpreter)->ApplyOptionsImpl(&options_);
769
770 (*interpreter)
771 ->SetProfilerImpl(tflite::profiling::MaybeCreatePlatformProfiler());
772
773 for (int subgraph_index = 0; subgraph_index < subgraphs->size();
774 ++subgraph_index) {
775 const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
776 tflite::Subgraph* modified_subgraph =
777 (*interpreter)->subgraph(subgraph_index);
778 auto operators = subgraph->operators();
779 auto tensors = subgraph->tensors();
780 if (!tensors) {
781 TF_LITE_REPORT_ERROR(error_reporter_,
782 "Did not get tensors in subgraph %d.\n",
783 subgraph_index);
784 return cleanup_and_error();
785 }
786 if (modified_subgraph->AddTensors(tensors->size()) != kTfLiteOk) {
787 return cleanup_and_error();
788 }
789 // Parse inputs/outputs
790 modified_subgraph->SetInputs(
791 FlatBufferIntArrayToVector(subgraph->inputs()));
792 modified_subgraph->SetOutputs(
793 FlatBufferIntArrayToVector(subgraph->outputs()));
794
795 // Finally setup nodes and tensors
796 // Parse tensors before nodes as ParseNodes checks input tensors for the
797 // nodes.
798 if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk)
799 return cleanup_and_error();
800 if (operators && ParseNodes(operators, modified_subgraph) != kTfLiteOk)
801 return cleanup_and_error();
802
803 std::vector<int> variables;
804 for (int i = 0; i < modified_subgraph->tensors_size(); ++i) {
805 auto* tensor = modified_subgraph->tensor(i);
806 if (tensor->is_variable) {
807 variables.push_back(i);
808 }
809 }
810 modified_subgraph->SetVariables(std::move(variables));
811 if (subgraph->name()) {
812 modified_subgraph->SetName(subgraph->name()->c_str());
813 }
814 }
815
816 if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
817 kTfLiteOk) {
818 return cleanup_and_error();
819 }
820
821 if ((*interpreter)->SetMetadata(metadata_) != kTfLiteOk) {
822 return cleanup_and_error();
823 }
824
825 if (ShouldCreateLazyDelegateProviders(num_fp32_tensors_)) {
826 (*interpreter)->lazy_delegate_providers_ =
827 op_resolver_.GetDelegateCreators();
828 }
829
830 TfLiteStatus status = ApplyDelegates(interpreter->get());
831 if (status != kTfLiteOk) {
832 interpreter->reset();
833 }
834 return status;
835}
836
837void InterpreterBuilder::AddDelegate(TfLiteDelegate* delegate) {
838 if (delegate == nullptr) {
839 TF_LITE_REPORT_ERROR(error_reporter_, "Null delegate.");
840 } else {
841 delegates_.push_back(delegate);
842 }
843}
844
845} // namespace tflite
846