1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/tools/optimize/calibration/calibrator.h" |
16 | |
17 | #include <fstream> |
18 | #include <memory> |
19 | #include <string> |
20 | #include <unordered_map> |
21 | #include <unordered_set> |
22 | #include <utility> |
23 | #include <vector> |
24 | |
25 | #include "absl/container/flat_hash_map.h" |
26 | #include "absl/memory/memory.h" |
27 | #include "tensorflow/lite/c/common.h" |
28 | #include "tensorflow/lite/core/api/error_reporter.h" |
29 | #include "tensorflow/lite/core/api/op_resolver.h" |
30 | #include "tensorflow/lite/interpreter.h" |
31 | #include "tensorflow/lite/kernels/kernel_util.h" |
32 | #include "tensorflow/lite/kernels/register.h" |
33 | #include "tensorflow/lite/minimal_logging.h" |
34 | #include "tensorflow/lite/model.h" |
35 | #include "tensorflow/lite/op_resolver.h" |
36 | #include "tensorflow/lite/schema/schema_generated.h" |
37 | #include "tensorflow/lite/schema/schema_utils.h" |
38 | #include "tensorflow/lite/stderr_reporter.h" |
39 | #include "tensorflow/lite/string_util.h" |
40 | #include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h" |
41 | #include "tensorflow/lite/tools/optimize/calibration/calibration_common.h" |
42 | #include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h" |
43 | #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h" |
44 | #include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h" |
45 | #include "tensorflow/lite/tools/optimize/calibration/logging_op.h" |
46 | #include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h" |
47 | |
48 | namespace tflite { |
49 | namespace optimize { |
50 | namespace calibration { |
51 | |
52 | namespace { |
53 | |
54 | // Calibrator is used to hold information that can be accessed during kernel |
55 | // invocations. |
56 | // TfLite kernel invocations are C functions and cannot look at the global |
57 | // structure of the graph. Calibrator allows the kernel invoke functions to |
58 | // access the global structure of graph and know which node is currently being |
59 | // executed. This also allows us to write a simple kernel invoke wrapper |
60 | // (see LoggingEval) that can work for most builtin ops. |
61 | class Calibrator { |
62 | public: |
63 | Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>& |
64 | node_ptr_opinfo_map, |
65 | std::unique_ptr<LoggingOpResolver> logging_op_resolver, |
66 | ErrorReporter* error_reporter) |
67 | : node_ptr_opinfo_map_(node_ptr_opinfo_map), |
68 | logging_op_resolver_(std::move(logging_op_resolver)), |
69 | error_reporter_(error_reporter) { |
70 | logger_ = std::make_unique<Logger>(); |
71 | } |
72 | |
73 | // Returns the wrapped kernel invoke function |TfLiteRegistration.invoke|. |
74 | KernelEvalFuncPtr GetKernelInvoke(const TfLiteNode* node) const; |
75 | |
76 | // Gets the instance of logger associated with the current context. |
77 | Logger* GetLogger() const { return logger_.get(); } |
78 | |
79 | // Gets the error reporter. |
80 | ErrorReporter* GetErrorReporter() const { return error_reporter_; } |
81 | |
82 | // Gets the operator information about the given TfLiteNode. |
83 | const OperatorInfo& GetOpInfo(const TfLiteNode* node) const { |
84 | return node_ptr_opinfo_map_.at(node); |
85 | } |
86 | |
87 | std::vector<const TfLiteNode*> GetNodesUnderCalibration() { |
88 | std::vector<const TfLiteNode*> nodes; |
89 | nodes.reserve(node_ptr_opinfo_map_.size()); |
90 | for (const auto& entry : node_ptr_opinfo_map_) { |
91 | nodes.push_back(entry.first); |
92 | } |
93 | return nodes; |
94 | } |
95 | |
96 | private: |
97 | std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map_; |
98 | std::unique_ptr<LoggingOpResolver> logging_op_resolver_; |
99 | const std::unordered_map<int, OperatorInfo> index_opinfo_; |
100 | std::unique_ptr<Logger> logger_; |
101 | ErrorReporter* error_reporter_; |
102 | }; |
103 | |
104 | KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const { |
105 | auto op_info = node_ptr_opinfo_map_.at(node); |
106 | if (op_info.is_custom_op) { |
107 | return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(), |
108 | op_info.version); |
109 | } |
110 | return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code, |
111 | op_info.version); |
112 | } |
113 | |
114 | // A registry of |Calibrator| objects per |TfLiteContext|. |
115 | // This global registry is needed to access |Calibrator| objects in the kernel |
116 | // invoke functions i.e. |TfLiteRegistration.invoke|. |
117 | // Kernel invoke functions are C functions that have limited access to |
118 | // |TfLiteContext|. Kernel invoke functions don't have access to global state of |
119 | // graph. That means during a kernel invocation, the function cannot know which |
120 | // node it was invoked for. E.g. in case of a model with |Conv| op at two |
121 | // locations, there is no easy way for the Conv.invoke function to disambiguate |
122 | // the calls. |
123 | // |
124 | // For calibration we solve this problem by creating a map of calibrators |
125 | // per |TfLiteContext|. This map is |GlobalCalibrationRegistry|. |
126 | // |
127 | // This registry is then accessed using a global getter function: |
128 | // |GetCalibratorRegistry|. |
129 | // E.g. |
130 | // TfLiteStatus SomeKernelInvokeFn(TfLiteContext* context, TfLiteNode* node) { |
131 | // .... code .... |
132 | // auto registry = GetCalibratorRegistry(); |
133 | // auto calibrator = registry->GetCalibrator(context); |
134 | // ..... code .... |
135 | // } |
136 | // |
137 | // This way the kernel invoke functions can get the access to the Calibrator |
138 | // object associated with the |TfLiteContext|. |
139 | class GlobalCalibratorRegistry { |
140 | public: |
141 | // Get the |Calibrator| associated with given context, returns null if no |
142 | // calibrator is associated with the given context. |
143 | Calibrator* GetCalibrator(const TfLiteNode* node) const { |
144 | if (node_to_calibrator_.find(node) == node_to_calibrator_.cend()) { |
145 | return nullptr; |
146 | } |
147 | return node_to_calibrator_.at(node); |
148 | } |
149 | |
150 | // Removes the association between calibrator and context. |
151 | // Note: This deletes the calibrator as well. |
152 | void RemoveCalibrator(const TfLiteContext* context) { |
153 | Calibrator* calibrator = calibrator_registry_.at(context).get(); |
154 | auto nodes = calibrator->GetNodesUnderCalibration(); |
155 | for (auto node : nodes) { |
156 | node_to_calibrator_.erase(node); |
157 | } |
158 | calibrator_registry_.erase(context); |
159 | } |
160 | |
161 | // Creates an instance of |Calibrator|. |
162 | // Registry owns the |Calibrator| object which can be deleted by calling |
163 | // |RemoveCalibrator|. |
164 | TfLiteStatus CreateCalibrator( |
165 | const TfLiteContext* context, |
166 | const std::unordered_map<const TfLiteNode*, OperatorInfo>& node_to_opinfo, |
167 | std::unique_ptr<LoggingOpResolver> logging_op_resolver, |
168 | Calibrator** calibrator_ptr, ErrorReporter* reporter) { |
169 | if (calibrator_registry_.find(context) != calibrator_registry_.cend()) { |
170 | reporter->Report( |
171 | "Failed to create calibrator, context already registered." ); |
172 | return kTfLiteError; |
173 | } |
174 | auto calibrator = std::make_unique<Calibrator>( |
175 | node_to_opinfo, std::move(logging_op_resolver), reporter); |
176 | calibrator_registry_[context] = std::move(calibrator); |
177 | *calibrator_ptr = calibrator_registry_.at(context).get(); |
178 | for (const auto& entry : node_to_opinfo) { |
179 | node_to_calibrator_[entry.first] = *calibrator_ptr; |
180 | } |
181 | return kTfLiteOk; |
182 | } |
183 | |
184 | private: |
185 | absl::flat_hash_map<const TfLiteContext*, std::unique_ptr<Calibrator>> |
186 | calibrator_registry_; |
187 | absl::flat_hash_map<const TfLiteNode*, Calibrator*> node_to_calibrator_; |
188 | }; |
189 | |
190 | GlobalCalibratorRegistry* GetCalibratorRegistry() { |
191 | static GlobalCalibratorRegistry* registry = new GlobalCalibratorRegistry(); |
192 | return registry; |
193 | } |
194 | |
195 | // Get the logging kernel if there are any. |
196 | // TODO(jianlijianli): extend this to support multiple recipe for the same |
197 | // model. |
198 | logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context, |
199 | TfLiteNode* node, |
200 | int builtin_op_code) { |
201 | switch (builtin_op_code) { |
202 | case BuiltinOperator_LSTM: { |
203 | if (node->intermediates->size == 12) { |
204 | return tflite::optimize::calibration::custom::lstm_logging_kernel; |
205 | } |
206 | return tflite::optimize::calibration::builtin::lstm_logging_kernel; |
207 | } |
208 | case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: |
209 | return tflite::optimize::calibration::builtin:: |
210 | unidirectional_sequence_lstm_logging_kernel; |
211 | default: |
212 | return nullptr; |
213 | } |
214 | } |
215 | |
216 | // A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs, |
217 | // invokes the wrapped implementation and then logs the outputs. |
218 | TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) { |
219 | Calibrator* calibrator = GetCalibratorRegistry()->GetCalibrator(node); |
220 | |
221 | if (!calibrator) { |
222 | TF_LITE_KERNEL_LOG(context, "No calibrator found for context." ); |
223 | return kTfLiteError; |
224 | } |
225 | |
226 | auto kernel_invoke = calibrator->GetKernelInvoke(node); |
227 | auto logger = calibrator->GetLogger(); |
228 | auto op_info = calibrator->GetOpInfo(node); |
229 | auto error_reporter = calibrator->GetErrorReporter(); |
230 | |
231 | for (int i : op_info.loggable_inputs) { |
232 | auto tensor = context->tensors[i]; |
233 | TF_LITE_ENSURE_STATUS( |
234 | logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f, |
235 | tensor.bytes / sizeof(float), error_reporter)); |
236 | } |
237 | auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code; |
238 | auto kernel_invoke_intermediate = |
239 | GetLoggingEvalFunc(context, node, builtin_op_code); |
240 | if (kernel_invoke_intermediate == nullptr) { |
241 | TF_LITE_ENSURE_STATUS(kernel_invoke(context, node)); |
242 | } else { |
243 | TF_LITE_ENSURE_STATUS( |
244 | kernel_invoke_intermediate(context, op_info.subgraph_index, node, |
245 | calibrator->GetLogger(), error_reporter)); |
246 | } |
247 | |
248 | // TODO(shashishekhar): An intermediate tensor in graph will get logged twice |
249 | // once as an input and second time as output. This doesn't change the min max |
250 | // values but is inefficient. |
251 | // Using moving average will also break this. |
252 | |
253 | // Log input again to make sure the state tensors are captured after lstm |
254 | // cell. |
255 | for (int i : op_info.loggable_inputs) { |
256 | auto tensor = context->tensors[i]; |
257 | TF_LITE_ENSURE_STATUS( |
258 | logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f, |
259 | tensor.bytes / sizeof(float), error_reporter)); |
260 | } |
261 | |
262 | for (int i : op_info.loggable_outputs) { |
263 | auto tensor = context->tensors[i]; |
264 | TF_LITE_ENSURE_STATUS( |
265 | logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f, |
266 | tensor.bytes / sizeof(float), error_reporter)); |
267 | } |
268 | |
269 | return kTfLiteOk; |
270 | } |
271 | |
272 | // Returns the loggable tensors. Not all inputs and outputs need to be logged. |
273 | // For example, const weight tensors which have buffers associated with them |
274 | // don't need to be logged. |
275 | std::vector<int> GetLoggableTensorIndices( |
276 | const std::vector<int>& tensor_indices, |
277 | const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors, |
278 | const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* tensor_buffers) { |
279 | std::vector<int> loggable; |
280 | for (auto tensor_index : tensor_indices) { |
281 | if (tensor_index == kTfLiteOptionalTensor) { |
282 | continue; |
283 | } |
284 | auto tensor = tensors->Get(tensor_index); |
285 | auto buffer_index = tensor->buffer(); |
286 | const bool has_no_buffer = |
287 | (tensor_buffers->Get(buffer_index) == nullptr) || |
288 | (tensor_buffers->Get(buffer_index)->data() == nullptr) || |
289 | (tensor_buffers->Get(buffer_index)->data()->size() == 0); |
290 | if (has_no_buffer && tensor->type() == tflite::TensorType_FLOAT32) { |
291 | loggable.push_back(tensor_index); |
292 | } |
293 | } |
294 | return loggable; |
295 | } |
296 | |
297 | // Creates a mapping between the static model graph and the runtime TfLiteNode* |
298 | // nodes in the graph for the given context. |
299 | // This is done by querying the TfLiteContext for node and registrations using |
300 | // the |NodeInfoDelegateObserver|. |
301 | TfLiteStatus GetNodeOpInfoMapAndContext( |
302 | const absl::flat_hash_map<std::tuple<int, int>, OperatorInfo>& |
303 | node_to_opinfo, |
304 | tflite::Interpreter* const interpreter, |
305 | std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map, |
306 | TfLiteContext** context) { |
307 | *context = interpreter->primary_subgraph().context(); |
308 | |
309 | // Since we only consider the primary subgraph while populating |
310 | // node_to_opinfo, do the same here. |
311 | // Because Flex delegate can merge multiple op nodes into one Delegate node if |
312 | // they are located in a row, the size of the execution plan can be lesser |
313 | // than the size of the graph's op nodes. |
314 | TF_LITE_ENSURE(*context, |
315 | interpreter->execution_plan().size() <= node_to_opinfo.size()); |
316 | for (const auto& entry : node_to_opinfo) { |
317 | auto op_info = entry.second; |
318 | int subgraph_index, op_index; |
319 | std::tie(subgraph_index, op_index) = entry.first; |
320 | const auto* node_and_reg = |
321 | interpreter->node_and_registration(subgraph_index, op_index); |
322 | op_info.registration = &node_and_reg->second; |
323 | node_ptr_opinfo_map->insert({&node_and_reg->first, op_info}); |
324 | } |
325 | return kTfLiteOk; |
326 | } |
327 | |
328 | string GetOpName(const tflite::OperatorCode& opcode) { |
329 | if (opcode.custom_code() != nullptr) { |
330 | return opcode.custom_code()->str(); |
331 | } |
332 | return tflite::EnumNamesBuiltinOperator()[GetBuiltinCode(&opcode)]; |
333 | } |
334 | |
// A |CalibrationReader| that owns the Calibrator.
// The reader keys off |context|: its destructor removes (and thereby
// destroys) the Calibrator registered for that context in the global
// registry, so the reader's lifetime bounds the calibrator's.
class Reader : public CalibrationReader {
 public:
  // |logger| is borrowed by the base class; it is owned by the Calibrator
  // registered for |context|, which stays alive until this reader dies.
  Reader(const TfLiteContext* context, const Logger* logger)
      : CalibrationReader(logger), context_(context) {}

  // Unregisters and deletes the calibrator associated with |context_|.
  ~Reader() override { GetCalibratorRegistry()->RemoveCalibrator(context_); }

 private:
  const TfLiteContext* context_;
};
346 | |
347 | bool HasInputs(BuiltinOperator code) { |
348 | switch (code) { |
349 | case BuiltinOperator_CALL_ONCE: |
350 | case BuiltinOperator_VAR_HANDLE: |
351 | // Custom ops, including Flex ops, might not have inputs. |
352 | case BuiltinOperator_CUSTOM: |
353 | return false; |
354 | default: |
355 | return true; |
356 | } |
357 | } |
358 | |
359 | bool HasOutputs(BuiltinOperator code) { |
360 | switch (code) { |
361 | case BuiltinOperator_ASSIGN_VARIABLE: |
362 | case BuiltinOperator_CALL_ONCE: |
363 | // Custom ops, including Flex ops, might not have outputs. |
364 | case BuiltinOperator_CUSTOM: |
365 | return false; |
366 | default: |
367 | return true; |
368 | } |
369 | } |
370 | |
371 | } // namespace |
372 | |
// Convenience overload: unpacks the raw flatbuffer model and the model's
// error reporter from |model| and forwards to the main implementation below.
TfLiteStatus BuildLoggingInterpreter(
    const FlatBufferModel& model, const OpResolver& op_resolver,
    std::unique_ptr<Interpreter>* interpreter,
    std::unique_ptr<CalibrationReader>* calibration_reader) {
  return BuildLoggingInterpreter(model.GetModel(), model.error_reporter(),
                                 op_resolver, interpreter, calibration_reader);
}
380 | |
381 | TfLiteStatus BuildLoggingInterpreter( |
382 | const tflite::Model* tflite_model, ErrorReporter* error_reporter, |
383 | const OpResolver& op_resolver, std::unique_ptr<Interpreter>* interpreter, |
384 | std::unique_ptr<CalibrationReader>* calibration_reader) { |
385 | if (error_reporter == nullptr) { |
386 | // Make sure error_reporter is valid. |
387 | error_reporter = DefaultErrorReporter(); |
388 | } |
389 | auto subgraphs = tflite_model->subgraphs(); |
390 | auto tensor_buffers = tflite_model->buffers(); |
391 | |
392 | // Populate the node index to operator info map. |
393 | // We want to collect this information so we can use it during runtime to |
394 | // log details of which inputs and outputs. |
395 | // At runtime TFLite kernel invoke functions can only look into their |
396 | // own node in the graph (TFLiteNode*) and some limited context information. |
397 | absl::flat_hash_map<std::tuple<int, int>, OperatorInfo> node_to_opinfo; |
398 | BuiltinOpsSet builtin_op_and_versions; |
399 | CustomOpsSet custom_op_and_versions; |
400 | |
401 | for (size_t subgraph_index = 0; subgraph_index < subgraphs->size(); |
402 | subgraph_index++) { |
403 | auto subgraph = subgraphs->Get(subgraph_index); |
404 | auto operator_codes = tflite_model->operator_codes(); |
405 | auto operators = subgraph->operators(); |
406 | auto tensors = subgraph->tensors(); |
407 | if (!operators) { |
408 | continue; |
409 | } |
410 | |
411 | for (size_t i = 0; i < operators->size(); i++) { |
412 | OperatorInfo op_info; |
413 | op_info.subgraph_index = subgraph_index; |
414 | op_info.node_index = i; |
415 | auto op = operators->Get(i); |
416 | auto operator_code = operator_codes->Get(op->opcode_index()); |
417 | op_info.builtin_op_code = GetBuiltinCode(operator_code); |
418 | op_info.name = GetOpName(*operator_code); |
419 | op_info.is_custom_op = operator_code->custom_code() != nullptr; |
420 | op_info.version = operator_code->version(); |
421 | |
422 | auto op_inputs = op->inputs(); |
423 | auto op_outputs = op->outputs(); |
424 | if (op_inputs) { |
425 | op_info.inputs = std::vector<int>(op_inputs->begin(), op_inputs->end()); |
426 | } else if (HasInputs(op_info.builtin_op_code)) { |
427 | TFLITE_LOG(TFLITE_LOG_WARNING, "Op %s missing inputs" , |
428 | op_info.name.c_str()); |
429 | } |
430 | if (op_outputs) { |
431 | op_info.outputs = |
432 | std::vector<int>(op_outputs->begin(), op_outputs->end()); |
433 | } else if (HasOutputs(op_info.builtin_op_code)) { |
434 | TFLITE_LOG(TFLITE_LOG_WARNING, "Op %s missing outputs" , |
435 | op_info.name.c_str()); |
436 | } |
437 | op_info.loggable_inputs = |
438 | GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers); |
439 | op_info.loggable_outputs = |
440 | GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers); |
441 | if (op_info.is_custom_op) { |
442 | op_info.registration = |
443 | op_resolver.FindOp(op_info.name.c_str(), operator_code->version()); |
444 | custom_op_and_versions.insert( |
445 | {op_info.name.c_str(), operator_code->version()}); |
446 | } else { |
447 | op_info.registration = op_resolver.FindOp(GetBuiltinCode(operator_code), |
448 | operator_code->version()); |
449 | builtin_op_and_versions.insert( |
450 | {op_info.builtin_op_code, operator_code->version()}); |
451 | } |
452 | std::tuple<int, int> key{subgraph_index, i}; |
453 | node_to_opinfo[key] = op_info; |
454 | } |
455 | } |
456 | |
457 | // Prepare the logging op resolver to use |LoggingEval| for kernel |
458 | // invocations. |
459 | auto logging_op_resolver = std::make_unique<LoggingOpResolver>( |
460 | builtin_op_and_versions, custom_op_and_versions, op_resolver, LoggingEval, |
461 | error_reporter); |
462 | tflite::InterpreterBuilder(tflite_model, *logging_op_resolver, |
463 | error_reporter)(interpreter); |
464 | |
465 | if (!(*interpreter)) { |
466 | error_reporter->Report("Failed to construct interpreter" ); |
467 | return kTfLiteError; |
468 | } |
469 | |
470 | // Compute the mapping between runtime and static graph structure, i.e. |
471 | // (TfLiteContext, TfLiteNode) -> OperatorInfo |
472 | std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map; |
473 | TfLiteContext* context = nullptr; |
474 | TF_LITE_ENSURE_STATUS(GetNodeOpInfoMapAndContext( |
475 | node_to_opinfo, interpreter->get(), &node_ptr_opinfo_map, &context)); |
476 | |
477 | Calibrator* calibrator = nullptr; |
478 | // Register a calibrator object for the context. This can be accessed |
479 | // during invocations by the logging kernels. |
480 | TF_LITE_ENSURE_STATUS(GetCalibratorRegistry()->CreateCalibrator( |
481 | context, node_ptr_opinfo_map, std::move(logging_op_resolver), &calibrator, |
482 | error_reporter)); |
483 | *calibration_reader = std::unique_ptr<CalibrationReader>( |
484 | new Reader(context, calibrator->GetLogger())); |
485 | |
486 | return kTfLiteOk; |
487 | } |
488 | |
489 | } // namespace calibration |
490 | } // namespace optimize |
491 | } // namespace tflite |
492 | |