1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
16
17#include <stdarg.h>
18
19#include <cstring>
20#include <functional>
21#include <memory>
22#include <sstream>
23#include <string>
24#include <utility>
25
26#include "absl/memory/memory.h"
27#include "absl/strings/str_format.h"
28#include "tensorflow/lite/c/common.h"
29#include "tensorflow/lite/core/api/error_reporter.h"
30#include "tensorflow/lite/core/api/op_resolver.h"
31#include "tensorflow/lite/interpreter.h"
32#include "tensorflow/lite/kernels/internal/compatibility.h"
33#include "tensorflow/lite/kernels/register.h"
34#include "tensorflow/lite/kernels/register_ref.h"
35#include "tensorflow/lite/model.h"
36#include "tensorflow/lite/mutable_op_resolver.h"
37#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
38#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
39#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
40#include "tensorflow/lite/shared_library.h"
41#include "tensorflow/lite/string_util.h"
42#include "tensorflow/lite/util.h"
43
// Evaluates `x` (a TfLiteStatus expression) and, when it is not kTfLiteOk,
// returns `error_reporter_->exception()` from the enclosing function.
// The do/while(0) wrapper makes the macro a single statement, so it composes
// safely with unbraced if/else at the call site.
#define TFLITE_PY_CHECK(x)                 \
  do {                                     \
    if ((x) != kTfLiteOk) {                \
      return error_reporter_->exception(); \
    }                                      \
  } while (0)
48
// Raises ValueError and returns nullptr from the enclosing function when `i`
// is not a valid index into interpreter_->tensors().
// tensors_size() is a size_t, so it is printed with %zu; %lu is only correct
// where unsigned long is 64 bits (not on LLP64 platforms such as Windows).
#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i)                                      \
  do {                                                                        \
    if ((i) < 0 ||                                                            \
        static_cast<size_t>(i) >= interpreter_->tensors_size()) {             \
      PyErr_Format(PyExc_ValueError,                                          \
                   "Invalid tensor index %d exceeds max tensor index %zu",    \
                   (i), interpreter_->tensors_size());                        \
      return nullptr;                                                         \
    }                                                                         \
  } while (0)
56
// Raises ValueError and returns nullptr from the enclosing function when `i`
// is not a valid tensor index of subgraph `subgraph_index`. The subgraph index
// itself is assumed to have been validated already.
// tensors_size() is a size_t, so it is printed with %zu (portable, unlike %lu).
#define TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index)          \
  do {                                                                     \
    if ((i) < 0 ||                                                         \
        static_cast<size_t>(i) >=                                          \
            interpreter_->subgraph(subgraph_index)->tensors_size()) {      \
      PyErr_Format(PyExc_ValueError,                                       \
                   "Invalid tensor index %d exceeds max tensor index %zu", \
                   (i),                                                    \
                   interpreter_->subgraph(subgraph_index)->tensors_size()); \
      return nullptr;                                                      \
    }                                                                      \
  } while (0)
64
// Raises ValueError and returns nullptr from the enclosing function when `i`
// is not a valid subgraph index.
// subgraphs_size() is a size_t, so it is printed with %zu (portable, unlike
// %lu).
#define TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(i)                                    \
  do {                                                                        \
    if ((i) < 0 ||                                                            \
        static_cast<size_t>(i) >= interpreter_->subgraphs_size()) {           \
      PyErr_Format(PyExc_ValueError,                                          \
                   "Invalid subgraph index %d exceeds max subgraph index %zu", \
                   (i), interpreter_->subgraphs_size());                      \
      return nullptr;                                                         \
    }                                                                         \
  } while (0)
72
// Raises ValueError("Invalid node index") and returns nullptr from the
// enclosing function when `i` is not a valid node index.
#define TFLITE_PY_NODES_BOUNDS_CHECK(i)                          \
  do {                                                           \
    if ((i) < 0 || (i) >= interpreter_->nodes_size()) {          \
      PyErr_SetString(PyExc_ValueError, "Invalid node index");   \
      return nullptr;                                            \
    }                                                            \
  } while (0)
78
// Raises ValueError and returns nullptr from the enclosing function when no
// interpreter is held (`interpreter_` is null).
#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                                 \
  do {                                                                       \
    if (interpreter_ == nullptr) {                                           \
      PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
      return nullptr;                                                        \
    }                                                                        \
  } while (0)
84
85namespace tflite {
86namespace interpreter_wrapper {
87
88namespace {
89
90using python_utils::PyDecrefDeleter;
91
92std::unique_ptr<Interpreter> CreateInterpreter(
93 const InterpreterWrapper::Model* model,
94 const tflite::MutableOpResolver& resolver, bool preserve_all_tensors) {
95 if (!model) {
96 return nullptr;
97 }
98
99 ::tflite::python::ImportNumpy();
100
101 std::unique_ptr<Interpreter> interpreter;
102 InterpreterOptions options;
103 options.SetPreserveAllTensors(preserve_all_tensors);
104 InterpreterBuilder builder(*model, resolver, &options);
105 if (builder(&interpreter) != kTfLiteOk) {
106 return nullptr;
107 }
108 return interpreter;
109}
110
111PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
112 void* pydata = malloc(size * sizeof(float));
113 memcpy(pydata, data, size * sizeof(float));
114 PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
115 PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
116 return obj;
117}
118
119PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
120 void* pydata = malloc(size * sizeof(int));
121 memcpy(pydata, data, size * sizeof(int));
122 PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
123 PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
124 return obj;
125}
126
127PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
128 PyObject* result = PyTuple_New(2);
129 PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
130 PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
131 return result;
132}
133
134PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) {
135 PyObject* result = PyDict_New();
136 PyDict_SetItemString(result, "traversal_order",
137 PyArrayFromIntVector(param.traversal_order->data,
138 param.traversal_order->size));
139 PyDict_SetItemString(
140 result, "block_map",
141 PyArrayFromIntVector(param.block_map->data, param.block_map->size));
142 PyObject* dim_metadata = PyList_New(param.dim_metadata_size);
143 for (int i = 0; i < param.dim_metadata_size; i++) {
144 PyObject* dim_metadata_i = PyDict_New();
145 if (param.dim_metadata[i].format == kTfLiteDimDense) {
146 PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(0));
147 PyDict_SetItemString(dim_metadata_i, "dense_size",
148 PyLong_FromSize_t(param.dim_metadata[i].dense_size));
149 } else {
150 PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(1));
151 const auto* array_segments = param.dim_metadata[i].array_segments;
152 const auto* array_indices = param.dim_metadata[i].array_indices;
153 PyDict_SetItemString(
154 dim_metadata_i, "array_segments",
155 PyArrayFromIntVector(array_segments->data, array_segments->size));
156 PyDict_SetItemString(
157 dim_metadata_i, "array_indices",
158 PyArrayFromIntVector(array_indices->data, array_indices->size));
159 }
160 PyList_SetItem(dim_metadata, i, dim_metadata_i);
161 }
162 PyDict_SetItemString(result, "dim_metadata", dim_metadata);
163 return result;
164}
165
166bool RegisterCustomOpByName(const char* registerer_name,
167 tflite::MutableOpResolver* resolver,
168 std::string* error_msg) {
169 // Registerer functions take a pointer to a BuiltinOpResolver as an input
170 // parameter and return void.
171 // TODO(b/137576229): We should implement this functionality in a more
172 // principled way.
173 typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);
174
175 // Look for the Registerer function by name.
176 RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
177 SharedLibrary::GetSymbol(registerer_name));
178
179 // Fail in an informative way if the function was not found.
180 if (registerer == nullptr) {
181 *error_msg =
182 absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
183 registerer_name, SharedLibrary::GetError());
184 return false;
185 }
186
187 // Call the registerer with the resolver.
188 registerer(resolver);
189 return true;
190}
191
192} // namespace
193
// Op resolver ids accepted by CreateInterpreterWrapper. The eventual caller in
// interpreter.py passes one of these; keep the values in sync with it (TODO:
// confirm against the Python side when changing them).
constexpr int kBuiltinOpResolver = 1;
constexpr int kBuiltinRefOpResolver = 2;
constexpr int kBuiltinOpResolverWithoutDefaultDelegates = 3;
197
198InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
199 std::unique_ptr<InterpreterWrapper::Model> model, int op_resolver_id,
200 std::unique_ptr<PythonErrorReporter> error_reporter,
201 const std::vector<std::string>& registerers_by_name,
202 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
203 std::string* error_msg, bool preserve_all_tensors) {
204 if (!model) {
205 *error_msg = error_reporter->message();
206 return nullptr;
207 }
208
209 std::unique_ptr<tflite::MutableOpResolver> resolver;
210 switch (op_resolver_id) {
211 case kBuiltinOpResolver:
212 resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
213 break;
214 case kBuiltinRefOpResolver:
215 resolver = std::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
216 break;
217 case kBuiltinOpResolverWithoutDefaultDelegates:
218 resolver = std::make_unique<
219 tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
220 break;
221 default:
222 // This should not never happen because the eventual caller in
223 // interpreter.py should have passed a valid id here.
224 TFLITE_DCHECK(false);
225 return nullptr;
226 }
227
228 for (const auto& registerer : registerers_by_name) {
229 if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
230 return nullptr;
231 }
232 for (const auto& registerer : registerers_by_func) {
233 registerer(reinterpret_cast<uintptr_t>(resolver.get()));
234 }
235 auto interpreter =
236 CreateInterpreter(model.get(), *resolver, preserve_all_tensors);
237 if (!interpreter) {
238 *error_msg = error_reporter->message();
239 return nullptr;
240 }
241
242 InterpreterWrapper* wrapper =
243 new InterpreterWrapper(std::move(model), std::move(error_reporter),
244 std::move(resolver), std::move(interpreter));
245 return wrapper;
246}
247
248InterpreterWrapper::InterpreterWrapper(
249 std::unique_ptr<InterpreterWrapper::Model> model,
250 std::unique_ptr<PythonErrorReporter> error_reporter,
251 std::unique_ptr<tflite::MutableOpResolver> resolver,
252 std::unique_ptr<Interpreter> interpreter)
253 : model_(std::move(model)),
254 error_reporter_(std::move(error_reporter)),
255 resolver_(std::move(resolver)),
256 interpreter_(std::move(interpreter)) {}
257
258InterpreterWrapper::~InterpreterWrapper() {}
259
260// LINT.IfChange
// Sentinel subgraph index used when the caller does not name a subgraph.
constexpr int kUndeterminedSubgraphIndex = -1;
262// LINT.ThenChange(//tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc)
263PyObject* InterpreterWrapper::AllocateTensors(int subgraph_index) {
264 TFLITE_PY_ENSURE_VALID_INTERPRETER();
265 if (subgraph_index == kUndeterminedSubgraphIndex) {
266 TFLITE_PY_CHECK(interpreter_->AllocateTensors());
267 } else {
268 TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
269 TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)->AllocateTensors());
270 }
271 Py_RETURN_NONE;
272}
273
274PyObject* InterpreterWrapper::Invoke(int subgraph_index) {
275 TFLITE_PY_ENSURE_VALID_INTERPRETER();
276 TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
277
278 // Release the GIL so that we can run multiple interpreters in parallel
279 TfLiteStatus status_code = kTfLiteOk;
280 Py_BEGIN_ALLOW_THREADS; // To return can happen between this and end!
281 tflite::Subgraph* subgraph = interpreter_->subgraph(subgraph_index);
282 status_code = subgraph->Invoke();
283
284 if (!interpreter_->allow_buffer_handle_output_) {
285 for (int tensor_index : subgraph->outputs()) {
286 subgraph->EnsureTensorDataIsReadable(tensor_index);
287 }
288 }
289 Py_END_ALLOW_THREADS;
290
291 TFLITE_PY_CHECK(
292 status_code); // don't move this into the Py_BEGIN/Py_End block
293
294 Py_RETURN_NONE;
295}
296
297PyObject* InterpreterWrapper::InputIndices() const {
298 TFLITE_PY_ENSURE_VALID_INTERPRETER();
299 PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
300 interpreter_->inputs().size());
301
302 return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
303}
304
305PyObject* InterpreterWrapper::OutputIndices() const {
306 PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
307 interpreter_->outputs().size());
308
309 return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
310}
311
312PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
313 TFLITE_PY_ENSURE_VALID_INTERPRETER();
314
315 std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
316 PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
317 if (!array_safe) {
318 PyErr_SetString(PyExc_ValueError,
319 "Failed to convert numpy value into readable tensor.");
320 return nullptr;
321 }
322
323 PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
324
325 if (PyArray_NDIM(array) != 1) {
326 PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
327 PyArray_NDIM(array));
328 return nullptr;
329 }
330
331 if (PyArray_TYPE(array) != NPY_INT32) {
332 PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
333 PyArray_TYPE(array));
334 return nullptr;
335 }
336
337 PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
338 NPY_ARRAY_OWNDATA);
339 return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
340}
341
342PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
343 bool strict,
344 int subgraph_index) {
345 TFLITE_PY_ENSURE_VALID_INTERPRETER();
346 TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
347
348 PyArrayObject* array =
349 reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
350 if (array == nullptr) {
351 return nullptr;
352 }
353
354 std::vector<int> dims(PyArray_SHAPE(array)[0]);
355 memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
356
357 if (strict) {
358 TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)
359 ->ResizeInputTensorStrict(i, dims));
360 } else {
361 TFLITE_PY_CHECK(
362 interpreter_->subgraph(subgraph_index)->ResizeInputTensor(i, dims));
363 }
364 Py_RETURN_NONE;
365}
366
367int InterpreterWrapper::NumTensors() const {
368 if (!interpreter_) {
369 return 0;
370 }
371 return interpreter_->tensors_size();
372}
373
374std::string InterpreterWrapper::TensorName(int i) const {
375 if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
376 return "";
377 }
378
379 const TfLiteTensor* tensor = interpreter_->tensor(i);
380 return tensor->name ? tensor->name : "";
381}
382
383PyObject* InterpreterWrapper::TensorType(int i) const {
384 TFLITE_PY_ENSURE_VALID_INTERPRETER();
385 TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
386
387 const TfLiteTensor* tensor = interpreter_->tensor(i);
388 if (tensor->type == kTfLiteNoType) {
389 PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
390 return nullptr;
391 }
392
393 int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
394 if (code == -1) {
395 PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
396 return nullptr;
397 }
398 return PyArray_TypeObjectFromType(code);
399}
400
401PyObject* InterpreterWrapper::TensorSize(int i) const {
402 TFLITE_PY_ENSURE_VALID_INTERPRETER();
403 TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
404
405 const TfLiteTensor* tensor = interpreter_->tensor(i);
406 if (tensor->dims == nullptr) {
407 PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
408 return nullptr;
409 }
410 PyObject* np_array =
411 PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
412
413 return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
414}
415
416PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
417 TFLITE_PY_ENSURE_VALID_INTERPRETER();
418 TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
419
420 const TfLiteTensor* tensor = interpreter_->tensor(i);
421 const int32_t* size_signature_data = nullptr;
422 int32_t size_signature_size = 0;
423 if (tensor->dims_signature != nullptr && tensor->dims_signature->size != 0) {
424 size_signature_data = tensor->dims_signature->data;
425 size_signature_size = tensor->dims_signature->size;
426 } else {
427 size_signature_data = tensor->dims->data;
428 size_signature_size = tensor->dims->size;
429 }
430 PyObject* np_array =
431 PyArrayFromIntVector(size_signature_data, size_signature_size);
432
433 return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
434}
435
436PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const {
437 TFLITE_PY_ENSURE_VALID_INTERPRETER();
438 TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
439 const TfLiteTensor* tensor = interpreter_->tensor(i);
440 if (tensor->sparsity == nullptr) {
441 return PyDict_New();
442 }
443
444 return PyDictFromSparsityParam(*tensor->sparsity);
445}
446
447PyObject* InterpreterWrapper::TensorQuantization(int i) const {
448 TFLITE_PY_ENSURE_VALID_INTERPRETER();
449 TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
450 const TfLiteTensor* tensor = interpreter_->tensor(i);
451 return PyTupleFromQuantizationParam(tensor->params);
452}
453
454PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const {
455 TFLITE_PY_ENSURE_VALID_INTERPRETER();
456 TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
457 const TfLiteTensor* tensor = interpreter_->tensor(i);
458 const TfLiteQuantization quantization = tensor->quantization;
459 float* scales_data = nullptr;
460 int32_t* zero_points_data = nullptr;
461 int32_t scales_size = 0;
462 int32_t zero_points_size = 0;
463 int32_t quantized_dimension = 0;
464 if (quantization.type == kTfLiteAffineQuantization) {
465 const TfLiteAffineQuantization* q_params =
466 reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params);
467 if (q_params->scale) {
468 scales_data = q_params->scale->data;
469 scales_size = q_params->scale->size;
470 }
471 if (q_params->zero_point) {
472 zero_points_data = q_params->zero_point->data;
473 zero_points_size = q_params->zero_point->size;
474 }
475 quantized_dimension = q_params->quantized_dimension;
476 }
477 PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size);
478 PyObject* zero_points_array =
479 PyArrayFromIntVector(zero_points_data, zero_points_size);
480
481 PyObject* result = PyTuple_New(3);
482 PyTuple_SET_ITEM(result, 0, scales_array);
483 PyTuple_SET_ITEM(result, 1, zero_points_array);
484 PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension));
485 return result;
486}
487
488PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value,
489 int subgraph_index) {
490 TFLITE_PY_ENSURE_VALID_INTERPRETER();
491 TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
492 TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index);
493
494 std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
495 PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
496 if (!array_safe) {
497 PyErr_SetString(PyExc_ValueError,
498 "Failed to convert value into readable tensor.");
499 return nullptr;
500 }
501
502 PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
503 TfLiteTensor* tensor = interpreter_->subgraph(subgraph_index)->tensor(i);
504
505 if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
506 PyErr_Format(PyExc_ValueError,
507 "Cannot set tensor:"
508 " Got value of type %s"
509 " but expected type %s for input %d, name: %s ",
510 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
511 TfLiteTypeGetName(tensor->type), i, tensor->name);
512 return nullptr;
513 }
514
515 if (PyArray_NDIM(array) != tensor->dims->size) {
516 PyErr_Format(PyExc_ValueError,
517 "Cannot set tensor: Dimension mismatch."
518 " Got %d"
519 " but expected %d for input %d.",
520 PyArray_NDIM(array), tensor->dims->size, i);
521 return nullptr;
522 }
523
524 for (int j = 0; j < PyArray_NDIM(array); j++) {
525 if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
526 PyErr_Format(PyExc_ValueError,
527 "Cannot set tensor: Dimension mismatch."
528 " Got %ld"
529 " but expected %d for dimension %d of input %d.",
530 PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
531 return nullptr;
532 }
533 }
534
535 if (tensor->type != kTfLiteString) {
536 // Only allow empty tensors.
537 if (tensor->data.raw == nullptr && tensor->bytes) {
538 PyErr_Format(PyExc_ValueError,
539 "Cannot set tensor:"
540 " Tensor is unallocated. Try calling allocate_tensors()"
541 " first");
542 return nullptr;
543 }
544
545 size_t size = PyArray_NBYTES(array);
546 if (size != tensor->bytes) {
547 PyErr_Format(PyExc_ValueError,
548 "numpy array had %zu bytes but expected %zu bytes.", size,
549 tensor->bytes);
550 return nullptr;
551 }
552 memcpy(tensor->data.raw, PyArray_DATA(array), size);
553 } else {
554 DynamicBuffer dynamic_buffer;
555 if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
556 return nullptr;
557 }
558 dynamic_buffer.WriteToTensor(tensor, nullptr);
559 }
560 Py_RETURN_NONE;
561}
562
563int InterpreterWrapper::NumNodes() const {
564 if (!interpreter_) {
565 return 0;
566 }
567 return interpreter_->nodes_size();
568}
569
570PyObject* InterpreterWrapper::NodeInputs(int i) const {
571 TFLITE_PY_ENSURE_VALID_INTERPRETER();
572 TFLITE_PY_NODES_BOUNDS_CHECK(i);
573
574 const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
575 PyObject* inputs =
576 PyArrayFromIntVector(node->inputs->data, node->inputs->size);
577 return inputs;
578}
579
580PyObject* InterpreterWrapper::NodeOutputs(int i) const {
581 TFLITE_PY_ENSURE_VALID_INTERPRETER();
582 TFLITE_PY_NODES_BOUNDS_CHECK(i);
583
584 const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
585 PyObject* outputs =
586 PyArrayFromIntVector(node->outputs->data, node->outputs->size);
587 return outputs;
588}
589
590std::string InterpreterWrapper::NodeName(int i) const {
591 if (!interpreter_ || i >= interpreter_->nodes_size() || i < 0) {
592 return "";
593 }
594 // Get op name from registration
595 const TfLiteRegistration* node_registration =
596 &(interpreter_->node_and_registration(i)->second);
597 int32_t op_code = node_registration->builtin_code;
598 std::string op_name;
599 if (op_code == tflite::BuiltinOperator_CUSTOM) {
600 const char* custom_name = node_registration->custom_name;
601 op_name = custom_name ? custom_name : "UnknownCustomOp";
602 } else {
603 op_name = tflite::EnumNamesBuiltinOperator()[op_code];
604 }
605 std::string op_name_str(op_name);
606 return op_name_str;
607}
608
609namespace {
610
611// Checks to see if a tensor access can succeed (returns nullptr on error).
612// Otherwise returns Py_None.
613PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
614 TfLiteTensor** tensor, int* type_num,
615 int subgraph_index) {
616 TFLITE_PY_ENSURE_VALID_INTERPRETER();
617 TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
618 TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(tensor_index, subgraph_index);
619
620 *tensor = interpreter_->subgraph(subgraph_index)->tensor(tensor_index);
621 // Invalid size only when bytes are 0 but pointer is allocated.
622 if ((*tensor)->bytes == 0 && (*tensor)->data.raw) {
623 PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
624 return nullptr;
625 }
626
627 *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
628 if (*type_num == -1) {
629 PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
630 return nullptr;
631 }
632
633 // Tensor data can't be null if size is > 0. 0 bytes is valid if tensor
634 // is empty.
635 if (!(*tensor)->data.raw && (*tensor)->bytes) {
636 PyErr_SetString(PyExc_ValueError,
637 "Tensor data is null."
638 " Run allocate_tensors() first");
639 return nullptr;
640 }
641
642 Py_RETURN_NONE;
643}
644
645} // namespace
646
647PyObject* InterpreterWrapper::GetSignatureDefs() const {
648 PyObject* result = PyDict_New();
649 for (const auto& sig_key : interpreter_->signature_keys()) {
650 PyObject* signature_def = PyDict_New();
651 PyObject* inputs = PyDict_New();
652 PyObject* outputs = PyDict_New();
653 const auto& signature_def_inputs =
654 interpreter_->signature_inputs(sig_key->c_str());
655 const auto& signature_def_outputs =
656 interpreter_->signature_outputs(sig_key->c_str());
657 for (const auto& input : signature_def_inputs) {
658 PyDict_SetItemString(inputs, input.first.c_str(),
659 PyLong_FromLong(input.second));
660 }
661 for (const auto& output : signature_def_outputs) {
662 PyDict_SetItemString(outputs, output.first.c_str(),
663 PyLong_FromLong(output.second));
664 }
665
666 PyDict_SetItemString(signature_def, "inputs", inputs);
667 PyDict_SetItemString(signature_def, "outputs", outputs);
668 PyDict_SetItemString(result, sig_key->c_str(), signature_def);
669 }
670 return result;
671}
672
673PyObject* InterpreterWrapper::GetSubgraphIndexFromSignature(
674 const char* signature_key) {
675 TFLITE_PY_ENSURE_VALID_INTERPRETER();
676
677 int32_t subgraph_index =
678 interpreter_->GetSubgraphIndexFromSignature(signature_key);
679
680 if (subgraph_index < 0) {
681 PyErr_SetString(PyExc_ValueError, "No matching signature.");
682 return nullptr;
683 }
684 return PyLong_FromLong(static_cast<int64_t>(subgraph_index));
685}
686
687PyObject* InterpreterWrapper::GetTensor(int i, int subgraph_index) const {
688 // Sanity check accessor
689 TfLiteTensor* tensor = nullptr;
690 int type_num = 0;
691
692 PyObject* check_result = CheckGetTensorArgs(interpreter_.get(), i, &tensor,
693 &type_num, subgraph_index);
694 if (check_result == nullptr) return check_result;
695 Py_XDECREF(check_result);
696
697 std::vector<npy_intp> dims(tensor->dims->data,
698 tensor->dims->data + tensor->dims->size);
699 if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
700 tensor->type != kTfLiteVariant) {
701 // Make a buffer copy but we must tell Numpy It owns that data or else
702 // it will leak.
703 void* data = malloc(tensor->bytes);
704 if (!data) {
705 PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
706 return nullptr;
707 }
708 memcpy(data, tensor->data.raw, tensor->bytes);
709 PyObject* np_array;
710 if (tensor->sparsity == nullptr) {
711 np_array =
712 PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
713 } else {
714 std::vector<npy_intp> sparse_buffer_dims(1);
715 size_t size_of_type;
716 if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
717 PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
718 free(data);
719 return nullptr;
720 }
721 sparse_buffer_dims[0] = tensor->bytes / size_of_type;
722 np_array = PyArray_SimpleNewFromData(
723 sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data);
724 }
725 PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
726 NPY_ARRAY_OWNDATA);
727 return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
728 } else {
729 // Create a C-order array so the data is contiguous in memory.
730 const int32_t kCOrder = 0;
731 PyObject* py_object =
732 PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);
733
734 if (py_object == nullptr) {
735 PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
736 return nullptr;
737 }
738
739 PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
740 PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
741 auto num_strings = GetStringCount(tensor);
742 for (int j = 0; j < num_strings; ++j) {
743 auto ref = GetString(tensor, j);
744
745 PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
746 if (bytes == nullptr) {
747 Py_DECREF(py_object);
748 PyErr_Format(PyExc_ValueError,
749 "Could not create PyBytes from string %d of input %d.", j,
750 i);
751 return nullptr;
752 }
753 // PyArray_EMPTY produces an array full of Py_None, which we must decref.
754 Py_DECREF(data[j]);
755 data[j] = bytes;
756 }
757 return py_object;
758 }
759}
760
761PyObject* InterpreterWrapper::tensor(PyObject* base_object, int tensor_index,
762 int subgraph_index) {
763 // Sanity check accessor
764 TfLiteTensor* tensor = nullptr;
765 int type_num = 0;
766
767 PyObject* check_result = CheckGetTensorArgs(
768 interpreter_.get(), tensor_index, &tensor, &type_num, subgraph_index);
769 if (check_result == nullptr) return check_result;
770 Py_XDECREF(check_result);
771
772 std::vector<npy_intp> dims(tensor->dims->data,
773 tensor->dims->data + tensor->dims->size);
774 PyArrayObject* np_array =
775 reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
776 dims.size(), dims.data(), type_num, tensor->data.raw));
777 Py_INCREF(base_object); // SetBaseObject steals, so we need to add.
778 PyArray_SetBaseObject(np_array, base_object);
779 return PyArray_Return(np_array);
780}
781
782InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
783 const char* model_path, int op_resolver_id,
784 const std::vector<std::string>& registerers_by_name,
785 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
786 std::string* error_msg, bool preserve_all_tensors) {
787 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
788 std::unique_ptr<InterpreterWrapper::Model> model =
789 Model::BuildFromFile(model_path, error_reporter.get());
790 return CreateInterpreterWrapper(std::move(model), op_resolver_id,
791 std::move(error_reporter),
792 registerers_by_name, registerers_by_func,
793 error_msg, preserve_all_tensors);
794}
795
796InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
797 const char* model_path, int op_resolver_id,
798 const std::vector<std::string>& registerers, std::string* error_msg,
799 bool preserve_all_tensors) {
800 return CreateWrapperCPPFromFile(model_path, op_resolver_id, registerers,
801 {} /*registerers_by_func*/, error_msg,
802 preserve_all_tensors);
803}
804
805InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
806 PyObject* data, int op_resolver_id,
807 const std::vector<std::string>& registerers_by_name,
808 const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
809 std::string* error_msg, bool preserve_all_tensors) {
810 char* buf = nullptr;
811 Py_ssize_t length;
812 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
813
814 if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
815 return nullptr;
816 }
817 std::unique_ptr<InterpreterWrapper::Model> model =
818 Model::BuildFromBuffer(buf, length, error_reporter.get());
819 return CreateInterpreterWrapper(std::move(model), op_resolver_id,
820 std::move(error_reporter),
821 registerers_by_name, registerers_by_func,
822 error_msg, preserve_all_tensors);
823}
824
825InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
826 PyObject* data, int op_resolver_id,
827 const std::vector<std::string>& registerers, std::string* error_msg,
828 bool preserve_all_tensors) {
829 return CreateWrapperCPPFromBuffer(data, op_resolver_id, registerers, {},
830 error_msg, preserve_all_tensors);
831}
832
833PyObject* InterpreterWrapper::ResetVariableTensors() {
834 TFLITE_PY_ENSURE_VALID_INTERPRETER();
835 TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
836 Py_RETURN_NONE;
837}
838
839PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
840 TFLITE_PY_ENSURE_VALID_INTERPRETER();
841 interpreter_->SetNumThreads(num_threads);
842 Py_RETURN_NONE;
843}
844
845PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
846 TfLiteDelegate* delegate) {
847 TFLITE_PY_ENSURE_VALID_INTERPRETER();
848 TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
849 Py_RETURN_NONE;
850}
851
852} // namespace interpreter_wrapper
853} // namespace tflite
854