1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" |
16 | |
17 | #include <stdarg.h> |
18 | |
19 | #include <cstring> |
20 | #include <functional> |
21 | #include <memory> |
22 | #include <sstream> |
23 | #include <string> |
24 | #include <utility> |
25 | |
26 | #include "absl/memory/memory.h" |
27 | #include "absl/strings/str_format.h" |
28 | #include "tensorflow/lite/c/common.h" |
29 | #include "tensorflow/lite/core/api/error_reporter.h" |
30 | #include "tensorflow/lite/core/api/op_resolver.h" |
31 | #include "tensorflow/lite/interpreter.h" |
32 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
33 | #include "tensorflow/lite/kernels/register.h" |
34 | #include "tensorflow/lite/kernels/register_ref.h" |
35 | #include "tensorflow/lite/model.h" |
36 | #include "tensorflow/lite/mutable_op_resolver.h" |
37 | #include "tensorflow/lite/python/interpreter_wrapper/numpy.h" |
38 | #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" |
39 | #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" |
40 | #include "tensorflow/lite/shared_library.h" |
41 | #include "tensorflow/lite/string_util.h" |
42 | #include "tensorflow/lite/util.h" |
43 | |
44 | #define TFLITE_PY_CHECK(x) \ |
45 | if ((x) != kTfLiteOk) { \ |
46 | return error_reporter_->exception(); \ |
47 | } |
48 | |
49 | #define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \ |
50 | if (i >= interpreter_->tensors_size() || i < 0) { \ |
51 | PyErr_Format(PyExc_ValueError, \ |
52 | "Invalid tensor index %d exceeds max tensor index %lu", i, \ |
53 | interpreter_->tensors_size()); \ |
54 | return nullptr; \ |
55 | } |
56 | |
57 | #define TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index) \ |
58 | if (i >= interpreter_->subgraph(subgraph_index)->tensors_size() || i < 0) { \ |
59 | PyErr_Format(PyExc_ValueError, \ |
60 | "Invalid tensor index %d exceeds max tensor index %lu", i, \ |
61 | interpreter_->subgraph(subgraph_index)->tensors_size()); \ |
62 | return nullptr; \ |
63 | } |
64 | |
65 | #define TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(i) \ |
66 | if (i >= interpreter_->subgraphs_size() || i < 0) { \ |
67 | PyErr_Format(PyExc_ValueError, \ |
68 | "Invalid subgraph index %d exceeds max subgraph index %lu", \ |
69 | i, interpreter_->subgraphs_size()); \ |
70 | return nullptr; \ |
71 | } |
72 | |
73 | #define TFLITE_PY_NODES_BOUNDS_CHECK(i) \ |
74 | if (i >= interpreter_->nodes_size() || i < 0) { \ |
75 | PyErr_Format(PyExc_ValueError, "Invalid node index"); \ |
76 | return nullptr; \ |
77 | } |
78 | |
79 | #define TFLITE_PY_ENSURE_VALID_INTERPRETER() \ |
80 | if (!interpreter_) { \ |
81 | PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \ |
82 | return nullptr; \ |
83 | } |
84 | |
85 | namespace tflite { |
86 | namespace interpreter_wrapper { |
87 | |
88 | namespace { |
89 | |
90 | using python_utils::PyDecrefDeleter; |
91 | |
92 | std::unique_ptr<Interpreter> CreateInterpreter( |
93 | const InterpreterWrapper::Model* model, |
94 | const tflite::MutableOpResolver& resolver, bool preserve_all_tensors) { |
95 | if (!model) { |
96 | return nullptr; |
97 | } |
98 | |
99 | ::tflite::python::ImportNumpy(); |
100 | |
101 | std::unique_ptr<Interpreter> interpreter; |
102 | InterpreterOptions options; |
103 | options.SetPreserveAllTensors(preserve_all_tensors); |
104 | InterpreterBuilder builder(*model, resolver, &options); |
105 | if (builder(&interpreter) != kTfLiteOk) { |
106 | return nullptr; |
107 | } |
108 | return interpreter; |
109 | } |
110 | |
111 | PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) { |
112 | void* pydata = malloc(size * sizeof(float)); |
113 | memcpy(pydata, data, size * sizeof(float)); |
114 | PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata); |
115 | PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA); |
116 | return obj; |
117 | } |
118 | |
119 | PyObject* PyArrayFromIntVector(const int* data, npy_intp size) { |
120 | void* pydata = malloc(size * sizeof(int)); |
121 | memcpy(pydata, data, size * sizeof(int)); |
122 | PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata); |
123 | PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA); |
124 | return obj; |
125 | } |
126 | |
127 | PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { |
128 | PyObject* result = PyTuple_New(2); |
129 | PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale)); |
130 | PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point)); |
131 | return result; |
132 | } |
133 | |
134 | PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) { |
135 | PyObject* result = PyDict_New(); |
136 | PyDict_SetItemString(result, "traversal_order" , |
137 | PyArrayFromIntVector(param.traversal_order->data, |
138 | param.traversal_order->size)); |
139 | PyDict_SetItemString( |
140 | result, "block_map" , |
141 | PyArrayFromIntVector(param.block_map->data, param.block_map->size)); |
142 | PyObject* dim_metadata = PyList_New(param.dim_metadata_size); |
143 | for (int i = 0; i < param.dim_metadata_size; i++) { |
144 | PyObject* dim_metadata_i = PyDict_New(); |
145 | if (param.dim_metadata[i].format == kTfLiteDimDense) { |
146 | PyDict_SetItemString(dim_metadata_i, "format" , PyLong_FromSize_t(0)); |
147 | PyDict_SetItemString(dim_metadata_i, "dense_size" , |
148 | PyLong_FromSize_t(param.dim_metadata[i].dense_size)); |
149 | } else { |
150 | PyDict_SetItemString(dim_metadata_i, "format" , PyLong_FromSize_t(1)); |
151 | const auto* array_segments = param.dim_metadata[i].array_segments; |
152 | const auto* array_indices = param.dim_metadata[i].array_indices; |
153 | PyDict_SetItemString( |
154 | dim_metadata_i, "array_segments" , |
155 | PyArrayFromIntVector(array_segments->data, array_segments->size)); |
156 | PyDict_SetItemString( |
157 | dim_metadata_i, "array_indices" , |
158 | PyArrayFromIntVector(array_indices->data, array_indices->size)); |
159 | } |
160 | PyList_SetItem(dim_metadata, i, dim_metadata_i); |
161 | } |
162 | PyDict_SetItemString(result, "dim_metadata" , dim_metadata); |
163 | return result; |
164 | } |
165 | |
166 | bool RegisterCustomOpByName(const char* registerer_name, |
167 | tflite::MutableOpResolver* resolver, |
168 | std::string* error_msg) { |
169 | // Registerer functions take a pointer to a BuiltinOpResolver as an input |
170 | // parameter and return void. |
171 | // TODO(b/137576229): We should implement this functionality in a more |
172 | // principled way. |
173 | typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*); |
174 | |
175 | // Look for the Registerer function by name. |
176 | RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>( |
177 | SharedLibrary::GetSymbol(registerer_name)); |
178 | |
179 | // Fail in an informative way if the function was not found. |
180 | if (registerer == nullptr) { |
181 | *error_msg = |
182 | absl::StrFormat("Looking up symbol '%s' failed with error '%s'." , |
183 | registerer_name, SharedLibrary::GetError()); |
184 | return false; |
185 | } |
186 | |
187 | // Call the registerer with the resolver. |
188 | registerer(resolver); |
189 | return true; |
190 | } |
191 | |
192 | } // namespace |
193 | |
// Numeric op-resolver ids accepted by CreateInterpreterWrapper(); presumably
// these must stay in sync with the values passed in from interpreter.py —
// TODO confirm against the Python caller.
static constexpr int kBuiltinOpResolver = 1;
static constexpr int kBuiltinRefOpResolver = 2;
static constexpr int kBuiltinOpResolverWithoutDefaultDelegates = 3;
197 | |
// Shared factory behind CreateWrapperCPPFromFile/Buffer. Selects an op
// resolver by id, runs all custom-op registerers against it, builds the
// interpreter, and bundles everything into a new InterpreterWrapper.
// On failure, fills `*error_msg` and returns nullptr; the caller owns the
// returned wrapper.
InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
    std::unique_ptr<InterpreterWrapper::Model> model, int op_resolver_id,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors) {
  if (!model) {
    // Model construction already reported why it failed.
    *error_msg = error_reporter->message();
    return nullptr;
  }

  std::unique_ptr<tflite::MutableOpResolver> resolver;
  switch (op_resolver_id) {
    case kBuiltinOpResolver:
      resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
      break;
    case kBuiltinRefOpResolver:
      resolver = std::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
      break;
    case kBuiltinOpResolverWithoutDefaultDelegates:
      resolver = std::make_unique<
          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
      break;
    default:
      // This should never happen because the eventual caller in
      // interpreter.py should have passed a valid id here.
      TFLITE_DCHECK(false);
      return nullptr;
  }

  // Named registerers are resolved from loaded shared libraries;
  // function registerers receive the resolver pointer as an integer.
  for (const auto& registerer : registerers_by_name) {
    if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
      return nullptr;
  }
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
  auto interpreter =
      CreateInterpreter(model.get(), *resolver, preserve_all_tensors);
  if (!interpreter) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  InterpreterWrapper* wrapper =
      new InterpreterWrapper(std::move(model), std::move(error_reporter),
                             std::move(resolver), std::move(interpreter));
  return wrapper;
}
247 | |
// Takes ownership of everything the interpreter needs to stay alive for its
// whole lifetime: the model that backs the flatbuffer, the error reporter,
// the op resolver, and the interpreter itself (destroyed in reverse order).
InterpreterWrapper::InterpreterWrapper(
    std::unique_ptr<InterpreterWrapper::Model> model,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    std::unique_ptr<tflite::MutableOpResolver> resolver,
    std::unique_ptr<Interpreter> interpreter)
    : model_(std::move(model)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      interpreter_(std::move(interpreter)) {}

InterpreterWrapper::~InterpreterWrapper() {}
259 | |
// LINT.IfChange
// Sentinel subgraph index meaning "no specific subgraph requested"; see
// AllocateTensors() below.
static constexpr int kUndeterminedSubgraphIndex = -1;
// LINT.ThenChange(//tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc)
263 | PyObject* InterpreterWrapper::AllocateTensors(int subgraph_index) { |
264 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
265 | if (subgraph_index == kUndeterminedSubgraphIndex) { |
266 | TFLITE_PY_CHECK(interpreter_->AllocateTensors()); |
267 | } else { |
268 | TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index); |
269 | TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)->AllocateTensors()); |
270 | } |
271 | Py_RETURN_NONE; |
272 | } |
273 | |
// Runs inference on subgraph `subgraph_index`. The GIL is released while the
// interpreter executes so several interpreters (or other Python threads) can
// run in parallel. Returns Py_None on success, nullptr with a Python
// exception set on failure.
PyObject* InterpreterWrapper::Invoke(int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);

  // Release the GIL so that we can run multiple interpreters in parallel
  TfLiteStatus status_code = kTfLiteOk;

  Py_BEGIN_ALLOW_THREADS;  // To return can happen between this and end!
  tflite::Subgraph* subgraph = interpreter_->subgraph(subgraph_index);
  status_code = subgraph->Invoke();

  // Make each output tensor's data host-readable (presumably relevant when a
  // delegate produced buffer handles — confirm against Subgraph docs),
  // unless the caller explicitly opted into buffer-handle outputs.
  if (!interpreter_->allow_buffer_handle_output_) {
    for (int tensor_index : subgraph->outputs()) {
      subgraph->EnsureTensorDataIsReadable(tensor_index);
    }
  }
  Py_END_ALLOW_THREADS;

  // TFLITE_PY_CHECK can `return`, which must not happen while the GIL is
  // released; that is why it sits outside the Py_BEGIN/Py_END block.
  TFLITE_PY_CHECK(
      status_code);  // don't move this into the Py_BEGIN/Py_End block

  Py_RETURN_NONE;
}
296 | |
297 | PyObject* InterpreterWrapper::InputIndices() const { |
298 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
299 | PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(), |
300 | interpreter_->inputs().size()); |
301 | |
302 | return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); |
303 | } |
304 | |
305 | PyObject* InterpreterWrapper::OutputIndices() const { |
306 | PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(), |
307 | interpreter_->outputs().size()); |
308 | |
309 | return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); |
310 | } |
311 | |
// Validates that `value` converts to a 1-D int32 numpy array (the requested
// shape for input `i`) and returns it as a new reference. Returns nullptr
// with a Python exception set on any validation failure.
//
// NOTE(review): NPY_ARRAY_OWNDATA is set on an array that PyArray_FromAny
// may have produced by referencing the caller's buffer — confirm this
// ownership transfer is intended before modifying.
PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  // Convert to a C-contiguous array; `array_safe` holds one reference that
  // is dropped on scope exit.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert numpy value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());

  // A shape is a single vector of dimension sizes.
  if (PyArray_NDIM(array) != 1) {
    PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
                 PyArray_NDIM(array));
    return nullptr;
  }

  if (PyArray_TYPE(array) != NPY_INT32) {
    PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
                 PyArray_TYPE(array));
    return nullptr;
  }

  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
                      NPY_ARRAY_OWNDATA);
  return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
}
341 | |
// Resizes input tensor `i` of subgraph `subgraph_index` to the shape given
// by `value` (a 1-D int32 numpy array). With `strict` set, the resize goes
// through ResizeInputTensorStrict instead of ResizeInputTensor. Returns
// Py_None on success, nullptr with a Python exception on failure.
PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
                                                bool strict,
                                                int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);

  // ResizeInputTensorImpl validates the shape array and sets the Python
  // exception itself on failure.
  PyArrayObject* array =
      reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
  if (array == nullptr) {
    return nullptr;
  }

  // Copy the requested dimensions out of the numpy buffer.
  std::vector<int> dims(PyArray_SHAPE(array)[0]);
  memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));

  if (strict) {
    TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)
                        ->ResizeInputTensorStrict(i, dims));
  } else {
    TFLITE_PY_CHECK(
        interpreter_->subgraph(subgraph_index)->ResizeInputTensor(i, dims));
  }
  Py_RETURN_NONE;
}
366 | |
367 | int InterpreterWrapper::NumTensors() const { |
368 | if (!interpreter_) { |
369 | return 0; |
370 | } |
371 | return interpreter_->tensors_size(); |
372 | } |
373 | |
374 | std::string InterpreterWrapper::TensorName(int i) const { |
375 | if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { |
376 | return "" ; |
377 | } |
378 | |
379 | const TfLiteTensor* tensor = interpreter_->tensor(i); |
380 | return tensor->name ? tensor->name : "" ; |
381 | } |
382 | |
383 | PyObject* InterpreterWrapper::TensorType(int i) const { |
384 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
385 | TFLITE_PY_TENSOR_BOUNDS_CHECK(i); |
386 | |
387 | const TfLiteTensor* tensor = interpreter_->tensor(i); |
388 | if (tensor->type == kTfLiteNoType) { |
389 | PyErr_Format(PyExc_ValueError, "Tensor with no type found." ); |
390 | return nullptr; |
391 | } |
392 | |
393 | int code = python_utils::TfLiteTypeToPyArrayType(tensor->type); |
394 | if (code == -1) { |
395 | PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d" , code); |
396 | return nullptr; |
397 | } |
398 | return PyArray_TypeObjectFromType(code); |
399 | } |
400 | |
401 | PyObject* InterpreterWrapper::TensorSize(int i) const { |
402 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
403 | TFLITE_PY_TENSOR_BOUNDS_CHECK(i); |
404 | |
405 | const TfLiteTensor* tensor = interpreter_->tensor(i); |
406 | if (tensor->dims == nullptr) { |
407 | PyErr_Format(PyExc_ValueError, "Tensor with no shape found." ); |
408 | return nullptr; |
409 | } |
410 | PyObject* np_array = |
411 | PyArrayFromIntVector(tensor->dims->data, tensor->dims->size); |
412 | |
413 | return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); |
414 | } |
415 | |
416 | PyObject* InterpreterWrapper::TensorSizeSignature(int i) const { |
417 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
418 | TFLITE_PY_TENSOR_BOUNDS_CHECK(i); |
419 | |
420 | const TfLiteTensor* tensor = interpreter_->tensor(i); |
421 | const int32_t* size_signature_data = nullptr; |
422 | int32_t size_signature_size = 0; |
423 | if (tensor->dims_signature != nullptr && tensor->dims_signature->size != 0) { |
424 | size_signature_data = tensor->dims_signature->data; |
425 | size_signature_size = tensor->dims_signature->size; |
426 | } else { |
427 | size_signature_data = tensor->dims->data; |
428 | size_signature_size = tensor->dims->size; |
429 | } |
430 | PyObject* np_array = |
431 | PyArrayFromIntVector(size_signature_data, size_signature_size); |
432 | |
433 | return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); |
434 | } |
435 | |
436 | PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const { |
437 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
438 | TFLITE_PY_TENSOR_BOUNDS_CHECK(i); |
439 | const TfLiteTensor* tensor = interpreter_->tensor(i); |
440 | if (tensor->sparsity == nullptr) { |
441 | return PyDict_New(); |
442 | } |
443 | |
444 | return PyDictFromSparsityParam(*tensor->sparsity); |
445 | } |
446 | |
447 | PyObject* InterpreterWrapper::TensorQuantization(int i) const { |
448 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
449 | TFLITE_PY_TENSOR_BOUNDS_CHECK(i); |
450 | const TfLiteTensor* tensor = interpreter_->tensor(i); |
451 | return PyTupleFromQuantizationParam(tensor->params); |
452 | } |
453 | |
454 | PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const { |
455 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
456 | TFLITE_PY_TENSOR_BOUNDS_CHECK(i); |
457 | const TfLiteTensor* tensor = interpreter_->tensor(i); |
458 | const TfLiteQuantization quantization = tensor->quantization; |
459 | float* scales_data = nullptr; |
460 | int32_t* zero_points_data = nullptr; |
461 | int32_t scales_size = 0; |
462 | int32_t zero_points_size = 0; |
463 | int32_t quantized_dimension = 0; |
464 | if (quantization.type == kTfLiteAffineQuantization) { |
465 | const TfLiteAffineQuantization* q_params = |
466 | reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params); |
467 | if (q_params->scale) { |
468 | scales_data = q_params->scale->data; |
469 | scales_size = q_params->scale->size; |
470 | } |
471 | if (q_params->zero_point) { |
472 | zero_points_data = q_params->zero_point->data; |
473 | zero_points_size = q_params->zero_point->size; |
474 | } |
475 | quantized_dimension = q_params->quantized_dimension; |
476 | } |
477 | PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size); |
478 | PyObject* zero_points_array = |
479 | PyArrayFromIntVector(zero_points_data, zero_points_size); |
480 | |
481 | PyObject* result = PyTuple_New(3); |
482 | PyTuple_SET_ITEM(result, 0, scales_array); |
483 | PyTuple_SET_ITEM(result, 1, zero_points_array); |
484 | PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension)); |
485 | return result; |
486 | } |
487 | |
// Copies the contents of numpy-convertible `value` into tensor `i` of
// subgraph `subgraph_index`. The value must match the tensor's type, rank,
// and every dimension exactly; string tensors are repacked via
// DynamicBuffer. Returns Py_None on success, nullptr with a Python
// exception set on failure.
PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value,
                                        int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
  TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index);

  // Convert to a C-contiguous array to copy from; `array_safe` releases the
  // reference on scope exit.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  TfLiteTensor* tensor = interpreter_->subgraph(subgraph_index)->tensor(i);

  // The dtype must match exactly; no implicit casting is performed.
  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got value of type %s"
                 " but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), i, tensor->name);
    return nullptr;
  }

  // Rank must match the tensor's current shape.
  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: Dimension mismatch."
                 " Got %d"
                 " but expected %d for input %d.",
                 PyArray_NDIM(array), tensor->dims->size, i);
    return nullptr;
  }

  // Every dimension must match as well (use ResizeInputTensor to change the
  // shape beforehand).
  for (int j = 0; j < PyArray_NDIM(array); j++) {
    if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Dimension mismatch."
                   " Got %ld"
                   " but expected %d for dimension %d of input %d.",
                   PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
      return nullptr;
    }
  }

  if (tensor->type != kTfLiteString) {
    // Only allow empty tensors.
    if (tensor->data.raw == nullptr && tensor->bytes) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor:"
                   " Tensor is unallocated. Try calling allocate_tensors()"
                   " first");
      return nullptr;
    }

    size_t size = PyArray_NBYTES(array);
    if (size != tensor->bytes) {
      PyErr_Format(PyExc_ValueError,
                   "numpy array had %zu bytes but expected %zu bytes.", size,
                   tensor->bytes);
      return nullptr;
    }
    // Non-string tensors are a straight byte copy into the tensor buffer.
    memcpy(tensor->data.raw, PyArray_DATA(array), size);
  } else {
    // String tensors use TFLite's packed string buffer format; rebuild the
    // buffer from the (object-dtype) array and write it into the tensor.
    DynamicBuffer dynamic_buffer;
    if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
      return nullptr;
    }
    dynamic_buffer.WriteToTensor(tensor, nullptr);
  }
  Py_RETURN_NONE;
}
562 | |
563 | int InterpreterWrapper::NumNodes() const { |
564 | if (!interpreter_) { |
565 | return 0; |
566 | } |
567 | return interpreter_->nodes_size(); |
568 | } |
569 | |
570 | PyObject* InterpreterWrapper::NodeInputs(int i) const { |
571 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
572 | TFLITE_PY_NODES_BOUNDS_CHECK(i); |
573 | |
574 | const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first); |
575 | PyObject* inputs = |
576 | PyArrayFromIntVector(node->inputs->data, node->inputs->size); |
577 | return inputs; |
578 | } |
579 | |
580 | PyObject* InterpreterWrapper::NodeOutputs(int i) const { |
581 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
582 | TFLITE_PY_NODES_BOUNDS_CHECK(i); |
583 | |
584 | const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first); |
585 | PyObject* outputs = |
586 | PyArrayFromIntVector(node->outputs->data, node->outputs->size); |
587 | return outputs; |
588 | } |
589 | |
590 | std::string InterpreterWrapper::NodeName(int i) const { |
591 | if (!interpreter_ || i >= interpreter_->nodes_size() || i < 0) { |
592 | return "" ; |
593 | } |
594 | // Get op name from registration |
595 | const TfLiteRegistration* node_registration = |
596 | &(interpreter_->node_and_registration(i)->second); |
597 | int32_t op_code = node_registration->builtin_code; |
598 | std::string op_name; |
599 | if (op_code == tflite::BuiltinOperator_CUSTOM) { |
600 | const char* custom_name = node_registration->custom_name; |
601 | op_name = custom_name ? custom_name : "UnknownCustomOp" ; |
602 | } else { |
603 | op_name = tflite::EnumNamesBuiltinOperator()[op_code]; |
604 | } |
605 | std::string op_name_str(op_name); |
606 | return op_name_str; |
607 | } |
608 | |
609 | namespace { |
610 | |
// Checks to see if a tensor access can succeed (returns nullptr on error).
// Otherwise returns Py_None.
//
// On success, fills `*tensor` and `*type_num` (the numpy type code). The
// parameter is deliberately named `interpreter_` so the TFLITE_PY_* macros,
// which reference that identifier, work inside this free function.
PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
                             TfLiteTensor** tensor, int* type_num,
                             int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
  TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(tensor_index, subgraph_index);

  *tensor = interpreter_->subgraph(subgraph_index)->tensor(tensor_index);
  // Invalid size only when bytes are 0 but pointer is allocated.
  if ((*tensor)->bytes == 0 && (*tensor)->data.raw) {
    PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
    return nullptr;
  }

  *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
  if (*type_num == -1) {
    PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
    return nullptr;
  }

  // Tensor data can't be null if size is > 0. 0 bytes is valid if tensor
  // is empty.
  if (!(*tensor)->data.raw && (*tensor)->bytes) {
    PyErr_SetString(PyExc_ValueError,
                    "Tensor data is null."
                    " Run allocate_tensors() first");
    return nullptr;
  }

  Py_RETURN_NONE;
}
644 | |
645 | } // namespace |
646 | |
647 | PyObject* InterpreterWrapper::GetSignatureDefs() const { |
648 | PyObject* result = PyDict_New(); |
649 | for (const auto& sig_key : interpreter_->signature_keys()) { |
650 | PyObject* signature_def = PyDict_New(); |
651 | PyObject* inputs = PyDict_New(); |
652 | PyObject* outputs = PyDict_New(); |
653 | const auto& signature_def_inputs = |
654 | interpreter_->signature_inputs(sig_key->c_str()); |
655 | const auto& signature_def_outputs = |
656 | interpreter_->signature_outputs(sig_key->c_str()); |
657 | for (const auto& input : signature_def_inputs) { |
658 | PyDict_SetItemString(inputs, input.first.c_str(), |
659 | PyLong_FromLong(input.second)); |
660 | } |
661 | for (const auto& output : signature_def_outputs) { |
662 | PyDict_SetItemString(outputs, output.first.c_str(), |
663 | PyLong_FromLong(output.second)); |
664 | } |
665 | |
666 | PyDict_SetItemString(signature_def, "inputs" , inputs); |
667 | PyDict_SetItemString(signature_def, "outputs" , outputs); |
668 | PyDict_SetItemString(result, sig_key->c_str(), signature_def); |
669 | } |
670 | return result; |
671 | } |
672 | |
673 | PyObject* InterpreterWrapper::GetSubgraphIndexFromSignature( |
674 | const char* signature_key) { |
675 | TFLITE_PY_ENSURE_VALID_INTERPRETER(); |
676 | |
677 | int32_t subgraph_index = |
678 | interpreter_->GetSubgraphIndexFromSignature(signature_key); |
679 | |
680 | if (subgraph_index < 0) { |
681 | PyErr_SetString(PyExc_ValueError, "No matching signature." ); |
682 | return nullptr; |
683 | } |
684 | return PyLong_FromLong(static_cast<int64_t>(subgraph_index)); |
685 | } |
686 | |
// Returns a COPY of tensor `i`'s data as a numpy array (contrast with
// tensor(), which aliases the live buffer). String/resource/variant tensors
// come back as object arrays of bytes; sparse tensors come back as a flat
// 1-D buffer. Returns nullptr with a Python exception set on failure.
PyObject* InterpreterWrapper::GetTensor(int i, int subgraph_index) const {
  // Sanity check accessor
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result = CheckGetTensorArgs(interpreter_.get(), i, &tensor,
                                              &type_num, subgraph_index);
  if (check_result == nullptr) return check_result;
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
      tensor->type != kTfLiteVariant) {
    // Make a buffer copy but we must tell Numpy It owns that data or else
    // it will leak.
    void* data = malloc(tensor->bytes);
    if (!data) {
      PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
      return nullptr;
    }
    memcpy(data, tensor->data.raw, tensor->bytes);
    PyObject* np_array;
    if (tensor->sparsity == nullptr) {
      np_array =
          PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
    } else {
      // A sparse tensor's dims describe the dense shape, which doesn't match
      // the physical buffer; expose the raw buffer as a flat 1-D array whose
      // length is derived from the byte size and the element size.
      std::vector<npy_intp> sparse_buffer_dims(1);
      size_t size_of_type;
      if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
        PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
        // Free the copy we made above since numpy never took ownership.
        free(data);
        return nullptr;
      }
      sparse_buffer_dims[0] = tensor->bytes / size_of_type;
      np_array = PyArray_SimpleNewFromData(
          sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data);
    }
    // Transfer buffer ownership to numpy so it is freed with the array.
    PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
                        NPY_ARRAY_OWNDATA);
    return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
  } else {
    // Create a C-order array so the data is contiguous in memory.
    const int32_t kCOrder = 0;
    PyObject* py_object =
        PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);

    if (py_object == nullptr) {
      PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
      return nullptr;
    }

    // Replace each element of the object array with a bytes object built
    // from the corresponding packed string in the tensor.
    PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
    PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
    auto num_strings = GetStringCount(tensor);
    for (int j = 0; j < num_strings; ++j) {
      auto ref = GetString(tensor, j);

      PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
      if (bytes == nullptr) {
        Py_DECREF(py_object);
        PyErr_Format(PyExc_ValueError,
                     "Could not create PyBytes from string %d of input %d.", j,
                     i);
        return nullptr;
      }
      // PyArray_EMPTY produces an array full of Py_None, which we must decref.
      Py_DECREF(data[j]);
      data[j] = bytes;
    }
    return py_object;
  }
}
760 | |
// Returns a numpy array that ALIASES (does not copy) the tensor's internal
// buffer. `base_object` — the Python-side wrapper object — is installed as
// the array's base so the interpreter (and hence the buffer) outlives any
// view handed out here.
PyObject* InterpreterWrapper::tensor(PyObject* base_object, int tensor_index,
                                     int subgraph_index) {
  // Sanity check accessor
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result = CheckGetTensorArgs(
      interpreter_.get(), tensor_index, &tensor, &type_num, subgraph_index);
  if (check_result == nullptr) return check_result;
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  PyArrayObject* np_array =
      reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
          dims.size(), dims.data(), type_num, tensor->data.raw));
  Py_INCREF(base_object);  // SetBaseObject steals, so we need to add.
  PyArray_SetBaseObject(np_array, base_object);
  return PyArray_Return(np_array);
}
781 | |
782 | InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( |
783 | const char* model_path, int op_resolver_id, |
784 | const std::vector<std::string>& registerers_by_name, |
785 | const std::vector<std::function<void(uintptr_t)>>& registerers_by_func, |
786 | std::string* error_msg, bool preserve_all_tensors) { |
787 | std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); |
788 | std::unique_ptr<InterpreterWrapper::Model> model = |
789 | Model::BuildFromFile(model_path, error_reporter.get()); |
790 | return CreateInterpreterWrapper(std::move(model), op_resolver_id, |
791 | std::move(error_reporter), |
792 | registerers_by_name, registerers_by_func, |
793 | error_msg, preserve_all_tensors); |
794 | } |
795 | |
796 | InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( |
797 | const char* model_path, int op_resolver_id, |
798 | const std::vector<std::string>& registerers, std::string* error_msg, |
799 | bool preserve_all_tensors) { |
800 | return CreateWrapperCPPFromFile(model_path, op_resolver_id, registerers, |
801 | {} /*registerers_by_func*/, error_msg, |
802 | preserve_all_tensors); |
803 | } |
804 | |
805 | InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( |
806 | PyObject* data, int op_resolver_id, |
807 | const std::vector<std::string>& registerers_by_name, |
808 | const std::vector<std::function<void(uintptr_t)>>& registerers_by_func, |
809 | std::string* error_msg, bool preserve_all_tensors) { |
810 | char* buf = nullptr; |
811 | Py_ssize_t length; |
812 | std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); |
813 | |
814 | if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) { |
815 | return nullptr; |
816 | } |
817 | std::unique_ptr<InterpreterWrapper::Model> model = |
818 | Model::BuildFromBuffer(buf, length, error_reporter.get()); |
819 | return CreateInterpreterWrapper(std::move(model), op_resolver_id, |
820 | std::move(error_reporter), |
821 | registerers_by_name, registerers_by_func, |
822 | error_msg, preserve_all_tensors); |
823 | } |
824 | |
825 | InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( |
826 | PyObject* data, int op_resolver_id, |
827 | const std::vector<std::string>& registerers, std::string* error_msg, |
828 | bool preserve_all_tensors) { |
829 | return CreateWrapperCPPFromBuffer(data, op_resolver_id, registerers, {}, |
830 | error_msg, preserve_all_tensors); |
831 | } |
832 | |
// Resets all variable tensors via Interpreter::ResetVariableTensors().
// Returns Py_None on success, nullptr with a Python exception on failure.
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}
838 | |
// Sets the number of threads the interpreter may use, then returns Py_None.
// NOTE(review): the status returned by SetNumThreads is discarded, so a
// failure here is silent — confirm that is intended.
PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  interpreter_->SetNumThreads(num_threads);
  Py_RETURN_NONE;
}
844 | |
// Applies `delegate` to the interpreter's graph. Returns Py_None on
// success; on failure raises the exception accumulated by the error
// reporter (via TFLITE_PY_CHECK).
PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
    TfLiteDelegate* delegate) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
  Py_RETURN_NONE;
}
851 | |
852 | } // namespace interpreter_wrapper |
853 | } // namespace tflite |
854 | |