1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
16
17#include <fstream>
18#include <memory>
19#include <string>
20#include <unordered_map>
21#include <unordered_set>
22#include <utility>
23#include <vector>
24
25#include "absl/container/flat_hash_map.h"
26#include "absl/memory/memory.h"
27#include "tensorflow/lite/c/common.h"
28#include "tensorflow/lite/core/api/error_reporter.h"
29#include "tensorflow/lite/core/api/op_resolver.h"
30#include "tensorflow/lite/interpreter.h"
31#include "tensorflow/lite/kernels/kernel_util.h"
32#include "tensorflow/lite/kernels/register.h"
33#include "tensorflow/lite/minimal_logging.h"
34#include "tensorflow/lite/model.h"
35#include "tensorflow/lite/op_resolver.h"
36#include "tensorflow/lite/schema/schema_generated.h"
37#include "tensorflow/lite/schema/schema_utils.h"
38#include "tensorflow/lite/stderr_reporter.h"
39#include "tensorflow/lite/string_util.h"
40#include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"
41#include "tensorflow/lite/tools/optimize/calibration/calibration_common.h"
42#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
43#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
44#include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h"
45#include "tensorflow/lite/tools/optimize/calibration/logging_op.h"
46#include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
47
48namespace tflite {
49namespace optimize {
50namespace calibration {
51
52namespace {
53
54// Calibrator is used to hold information that can be accessed during kernel
55// invocations.
56// TfLite kernel invocations are C functions and cannot look at the global
57// structure of the graph. Calibrator allows the kernel invoke functions to
58// access the global structure of graph and know which node is currently being
59// executed. This also allows us to write a simple kernel invoke wrapper
60// (see LoggingEval) that can work for most builtin ops.
61class Calibrator {
62 public:
63 Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
64 node_ptr_opinfo_map,
65 std::unique_ptr<LoggingOpResolver> logging_op_resolver,
66 ErrorReporter* error_reporter)
67 : node_ptr_opinfo_map_(node_ptr_opinfo_map),
68 logging_op_resolver_(std::move(logging_op_resolver)),
69 error_reporter_(error_reporter) {
70 logger_ = std::make_unique<Logger>();
71 }
72
73 // Returns the wrapped kernel invoke function |TfLiteRegistration.invoke|.
74 KernelEvalFuncPtr GetKernelInvoke(const TfLiteNode* node) const;
75
76 // Gets the instance of logger associated with the current context.
77 Logger* GetLogger() const { return logger_.get(); }
78
79 // Gets the error reporter.
80 ErrorReporter* GetErrorReporter() const { return error_reporter_; }
81
82 // Gets the operator information about the given TfLiteNode.
83 const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
84 return node_ptr_opinfo_map_.at(node);
85 }
86
87 std::vector<const TfLiteNode*> GetNodesUnderCalibration() {
88 std::vector<const TfLiteNode*> nodes;
89 nodes.reserve(node_ptr_opinfo_map_.size());
90 for (const auto& entry : node_ptr_opinfo_map_) {
91 nodes.push_back(entry.first);
92 }
93 return nodes;
94 }
95
96 private:
97 std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map_;
98 std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
99 const std::unordered_map<int, OperatorInfo> index_opinfo_;
100 std::unique_ptr<Logger> logger_;
101 ErrorReporter* error_reporter_;
102};
103
104KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
105 auto op_info = node_ptr_opinfo_map_.at(node);
106 if (op_info.is_custom_op) {
107 return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(),
108 op_info.version);
109 }
110 return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
111 op_info.version);
112}
113
114// A registry of |Calibrator| objects per |TfLiteContext|.
115// This global registry is needed to access |Calibrator| objects in the kernel
116// invoke functions i.e. |TfLiteRegistration.invoke|.
117// Kernel invoke functions are C functions that have limited access to
118// |TfLiteContext|. Kernel invoke functions don't have access to global state of
119// graph. That means during a kernel invocation, the function cannot know which
120// node it was invoked for. E.g. in case of a model with |Conv| op at two
121// locations, there is no easy way for the Conv.invoke function to disambiguate
122// the calls.
123//
124// For calibration we solve this problem by creating a map of calibrators
125// per |TfLiteContext|. This map is |GlobalCalibrationRegistry|.
126//
127// This registry is then accessed using a global getter function:
128// |GetCalibratorRegistry|.
129// E.g.
130// TfLiteStatus SomeKernelInvokeFn(TfLiteContext* context, TfLiteNode* node) {
131// .... code ....
132// auto registry = GetCalibratorRegistry();
133// auto calibrator = registry->GetCalibrator(context);
134// ..... code ....
135// }
136//
137// This way the kernel invoke functions can get the access to the Calibrator
138// object associated with the |TfLiteContext|.
139class GlobalCalibratorRegistry {
140 public:
141 // Get the |Calibrator| associated with given context, returns null if no
142 // calibrator is associated with the given context.
143 Calibrator* GetCalibrator(const TfLiteNode* node) const {
144 if (node_to_calibrator_.find(node) == node_to_calibrator_.cend()) {
145 return nullptr;
146 }
147 return node_to_calibrator_.at(node);
148 }
149
150 // Removes the association between calibrator and context.
151 // Note: This deletes the calibrator as well.
152 void RemoveCalibrator(const TfLiteContext* context) {
153 Calibrator* calibrator = calibrator_registry_.at(context).get();
154 auto nodes = calibrator->GetNodesUnderCalibration();
155 for (auto node : nodes) {
156 node_to_calibrator_.erase(node);
157 }
158 calibrator_registry_.erase(context);
159 }
160
161 // Creates an instance of |Calibrator|.
162 // Registry owns the |Calibrator| object which can be deleted by calling
163 // |RemoveCalibrator|.
164 TfLiteStatus CreateCalibrator(
165 const TfLiteContext* context,
166 const std::unordered_map<const TfLiteNode*, OperatorInfo>& node_to_opinfo,
167 std::unique_ptr<LoggingOpResolver> logging_op_resolver,
168 Calibrator** calibrator_ptr, ErrorReporter* reporter) {
169 if (calibrator_registry_.find(context) != calibrator_registry_.cend()) {
170 reporter->Report(
171 "Failed to create calibrator, context already registered.");
172 return kTfLiteError;
173 }
174 auto calibrator = std::make_unique<Calibrator>(
175 node_to_opinfo, std::move(logging_op_resolver), reporter);
176 calibrator_registry_[context] = std::move(calibrator);
177 *calibrator_ptr = calibrator_registry_.at(context).get();
178 for (const auto& entry : node_to_opinfo) {
179 node_to_calibrator_[entry.first] = *calibrator_ptr;
180 }
181 return kTfLiteOk;
182 }
183
184 private:
185 absl::flat_hash_map<const TfLiteContext*, std::unique_ptr<Calibrator>>
186 calibrator_registry_;
187 absl::flat_hash_map<const TfLiteNode*, Calibrator*> node_to_calibrator_;
188};
189
190GlobalCalibratorRegistry* GetCalibratorRegistry() {
191 static GlobalCalibratorRegistry* registry = new GlobalCalibratorRegistry();
192 return registry;
193}
194
195// Get the logging kernel if there are any.
196// TODO(jianlijianli): extend this to support multiple recipe for the same
197// model.
198logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
199 TfLiteNode* node,
200 int builtin_op_code) {
201 switch (builtin_op_code) {
202 case BuiltinOperator_LSTM: {
203 if (node->intermediates->size == 12) {
204 return tflite::optimize::calibration::custom::lstm_logging_kernel;
205 }
206 return tflite::optimize::calibration::builtin::lstm_logging_kernel;
207 }
208 case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
209 return tflite::optimize::calibration::builtin::
210 unidirectional_sequence_lstm_logging_kernel;
211 default:
212 return nullptr;
213 }
214}
215
216// A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
217// invokes the wrapped implementation and then logs the outputs.
218TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
219 Calibrator* calibrator = GetCalibratorRegistry()->GetCalibrator(node);
220
221 if (!calibrator) {
222 TF_LITE_KERNEL_LOG(context, "No calibrator found for context.");
223 return kTfLiteError;
224 }
225
226 auto kernel_invoke = calibrator->GetKernelInvoke(node);
227 auto logger = calibrator->GetLogger();
228 auto op_info = calibrator->GetOpInfo(node);
229 auto error_reporter = calibrator->GetErrorReporter();
230
231 for (int i : op_info.loggable_inputs) {
232 auto tensor = context->tensors[i];
233 TF_LITE_ENSURE_STATUS(
234 logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
235 tensor.bytes / sizeof(float), error_reporter));
236 }
237 auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code;
238 auto kernel_invoke_intermediate =
239 GetLoggingEvalFunc(context, node, builtin_op_code);
240 if (kernel_invoke_intermediate == nullptr) {
241 TF_LITE_ENSURE_STATUS(kernel_invoke(context, node));
242 } else {
243 TF_LITE_ENSURE_STATUS(
244 kernel_invoke_intermediate(context, op_info.subgraph_index, node,
245 calibrator->GetLogger(), error_reporter));
246 }
247
248 // TODO(shashishekhar): An intermediate tensor in graph will get logged twice
249 // once as an input and second time as output. This doesn't change the min max
250 // values but is inefficient.
251 // Using moving average will also break this.
252
253 // Log input again to make sure the state tensors are captured after lstm
254 // cell.
255 for (int i : op_info.loggable_inputs) {
256 auto tensor = context->tensors[i];
257 TF_LITE_ENSURE_STATUS(
258 logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
259 tensor.bytes / sizeof(float), error_reporter));
260 }
261
262 for (int i : op_info.loggable_outputs) {
263 auto tensor = context->tensors[i];
264 TF_LITE_ENSURE_STATUS(
265 logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
266 tensor.bytes / sizeof(float), error_reporter));
267 }
268
269 return kTfLiteOk;
270}
271
272// Returns the loggable tensors. Not all inputs and outputs need to be logged.
273// For example, const weight tensors which have buffers associated with them
274// don't need to be logged.
275std::vector<int> GetLoggableTensorIndices(
276 const std::vector<int>& tensor_indices,
277 const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
278 const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* tensor_buffers) {
279 std::vector<int> loggable;
280 for (auto tensor_index : tensor_indices) {
281 if (tensor_index == kTfLiteOptionalTensor) {
282 continue;
283 }
284 auto tensor = tensors->Get(tensor_index);
285 auto buffer_index = tensor->buffer();
286 const bool has_no_buffer =
287 (tensor_buffers->Get(buffer_index) == nullptr) ||
288 (tensor_buffers->Get(buffer_index)->data() == nullptr) ||
289 (tensor_buffers->Get(buffer_index)->data()->size() == 0);
290 if (has_no_buffer && tensor->type() == tflite::TensorType_FLOAT32) {
291 loggable.push_back(tensor_index);
292 }
293 }
294 return loggable;
295}
296
297// Creates a mapping between the static model graph and the runtime TfLiteNode*
298// nodes in the graph for the given context.
299// This is done by querying the TfLiteContext for node and registrations using
300// the |NodeInfoDelegateObserver|.
301TfLiteStatus GetNodeOpInfoMapAndContext(
302 const absl::flat_hash_map<std::tuple<int, int>, OperatorInfo>&
303 node_to_opinfo,
304 tflite::Interpreter* const interpreter,
305 std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
306 TfLiteContext** context) {
307 *context = interpreter->primary_subgraph().context();
308
309 // Since we only consider the primary subgraph while populating
310 // node_to_opinfo, do the same here.
311 // Because Flex delegate can merge multiple op nodes into one Delegate node if
312 // they are located in a row, the size of the execution plan can be lesser
313 // than the size of the graph's op nodes.
314 TF_LITE_ENSURE(*context,
315 interpreter->execution_plan().size() <= node_to_opinfo.size());
316 for (const auto& entry : node_to_opinfo) {
317 auto op_info = entry.second;
318 int subgraph_index, op_index;
319 std::tie(subgraph_index, op_index) = entry.first;
320 const auto* node_and_reg =
321 interpreter->node_and_registration(subgraph_index, op_index);
322 op_info.registration = &node_and_reg->second;
323 node_ptr_opinfo_map->insert({&node_and_reg->first, op_info});
324 }
325 return kTfLiteOk;
326}
327
328string GetOpName(const tflite::OperatorCode& opcode) {
329 if (opcode.custom_code() != nullptr) {
330 return opcode.custom_code()->str();
331 }
332 return tflite::EnumNamesBuiltinOperator()[GetBuiltinCode(&opcode)];
333}
334
// A |CalibrationReader| that owns the Calibrator.
class Reader : public CalibrationReader {
 public:
  // |context| identifies which calibrator to tear down on destruction;
  // |logger| holds the statistics this reader exposes (both not owned here,
  // the calibrator is owned by the global registry).
  Reader(const TfLiteContext* context, const Logger* logger)
      : CalibrationReader(logger), context_(context) {}

  // Unregisters (and thereby deletes) the calibrator for |context_|.
  ~Reader() override { GetCalibratorRegistry()->RemoveCalibrator(context_); }

 private:
  const TfLiteContext* context_;
};
346
347bool HasInputs(BuiltinOperator code) {
348 switch (code) {
349 case BuiltinOperator_CALL_ONCE:
350 case BuiltinOperator_VAR_HANDLE:
351 // Custom ops, including Flex ops, might not have inputs.
352 case BuiltinOperator_CUSTOM:
353 return false;
354 default:
355 return true;
356 }
357}
358
359bool HasOutputs(BuiltinOperator code) {
360 switch (code) {
361 case BuiltinOperator_ASSIGN_VARIABLE:
362 case BuiltinOperator_CALL_ONCE:
363 // Custom ops, including Flex ops, might not have outputs.
364 case BuiltinOperator_CUSTOM:
365 return false;
366 default:
367 return true;
368 }
369}
370
371} // namespace
372
// Convenience overload: unwraps the flatbuffer model and its error reporter
// from |model| and forwards to the full implementation.
TfLiteStatus BuildLoggingInterpreter(
    const FlatBufferModel& model, const OpResolver& op_resolver,
    std::unique_ptr<Interpreter>* interpreter,
    std::unique_ptr<CalibrationReader>* calibration_reader) {
  return BuildLoggingInterpreter(model.GetModel(), model.error_reporter(),
                                 op_resolver, interpreter, calibration_reader);
}
380
381TfLiteStatus BuildLoggingInterpreter(
382 const tflite::Model* tflite_model, ErrorReporter* error_reporter,
383 const OpResolver& op_resolver, std::unique_ptr<Interpreter>* interpreter,
384 std::unique_ptr<CalibrationReader>* calibration_reader) {
385 if (error_reporter == nullptr) {
386 // Make sure error_reporter is valid.
387 error_reporter = DefaultErrorReporter();
388 }
389 auto subgraphs = tflite_model->subgraphs();
390 auto tensor_buffers = tflite_model->buffers();
391
392 // Populate the node index to operator info map.
393 // We want to collect this information so we can use it during runtime to
394 // log details of which inputs and outputs.
395 // At runtime TFLite kernel invoke functions can only look into their
396 // own node in the graph (TFLiteNode*) and some limited context information.
397 absl::flat_hash_map<std::tuple<int, int>, OperatorInfo> node_to_opinfo;
398 BuiltinOpsSet builtin_op_and_versions;
399 CustomOpsSet custom_op_and_versions;
400
401 for (size_t subgraph_index = 0; subgraph_index < subgraphs->size();
402 subgraph_index++) {
403 auto subgraph = subgraphs->Get(subgraph_index);
404 auto operator_codes = tflite_model->operator_codes();
405 auto operators = subgraph->operators();
406 auto tensors = subgraph->tensors();
407 if (!operators) {
408 continue;
409 }
410
411 for (size_t i = 0; i < operators->size(); i++) {
412 OperatorInfo op_info;
413 op_info.subgraph_index = subgraph_index;
414 op_info.node_index = i;
415 auto op = operators->Get(i);
416 auto operator_code = operator_codes->Get(op->opcode_index());
417 op_info.builtin_op_code = GetBuiltinCode(operator_code);
418 op_info.name = GetOpName(*operator_code);
419 op_info.is_custom_op = operator_code->custom_code() != nullptr;
420 op_info.version = operator_code->version();
421
422 auto op_inputs = op->inputs();
423 auto op_outputs = op->outputs();
424 if (op_inputs) {
425 op_info.inputs = std::vector<int>(op_inputs->begin(), op_inputs->end());
426 } else if (HasInputs(op_info.builtin_op_code)) {
427 TFLITE_LOG(TFLITE_LOG_WARNING, "Op %s missing inputs",
428 op_info.name.c_str());
429 }
430 if (op_outputs) {
431 op_info.outputs =
432 std::vector<int>(op_outputs->begin(), op_outputs->end());
433 } else if (HasOutputs(op_info.builtin_op_code)) {
434 TFLITE_LOG(TFLITE_LOG_WARNING, "Op %s missing outputs",
435 op_info.name.c_str());
436 }
437 op_info.loggable_inputs =
438 GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
439 op_info.loggable_outputs =
440 GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
441 if (op_info.is_custom_op) {
442 op_info.registration =
443 op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
444 custom_op_and_versions.insert(
445 {op_info.name.c_str(), operator_code->version()});
446 } else {
447 op_info.registration = op_resolver.FindOp(GetBuiltinCode(operator_code),
448 operator_code->version());
449 builtin_op_and_versions.insert(
450 {op_info.builtin_op_code, operator_code->version()});
451 }
452 std::tuple<int, int> key{subgraph_index, i};
453 node_to_opinfo[key] = op_info;
454 }
455 }
456
457 // Prepare the logging op resolver to use |LoggingEval| for kernel
458 // invocations.
459 auto logging_op_resolver = std::make_unique<LoggingOpResolver>(
460 builtin_op_and_versions, custom_op_and_versions, op_resolver, LoggingEval,
461 error_reporter);
462 tflite::InterpreterBuilder(tflite_model, *logging_op_resolver,
463 error_reporter)(interpreter);
464
465 if (!(*interpreter)) {
466 error_reporter->Report("Failed to construct interpreter");
467 return kTfLiteError;
468 }
469
470 // Compute the mapping between runtime and static graph structure, i.e.
471 // (TfLiteContext, TfLiteNode) -> OperatorInfo
472 std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map;
473 TfLiteContext* context = nullptr;
474 TF_LITE_ENSURE_STATUS(GetNodeOpInfoMapAndContext(
475 node_to_opinfo, interpreter->get(), &node_ptr_opinfo_map, &context));
476
477 Calibrator* calibrator = nullptr;
478 // Register a calibrator object for the context. This can be accessed
479 // during invocations by the logging kernels.
480 TF_LITE_ENSURE_STATUS(GetCalibratorRegistry()->CreateCalibrator(
481 context, node_ptr_opinfo_map, std::move(logging_op_resolver), &calibrator,
482 error_reporter));
483 *calibration_reader = std::unique_ptr<CalibrationReader>(
484 new Reader(context, calibrator->GetLogger()));
485
486 return kTfLiteOk;
487}
488
489} // namespace calibration
490} // namespace optimize
491} // namespace tflite
492